import geopandas as gpd
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
import os
import pyarrow as pa
from time import time
# Proximity model parameters. Distances are expressed in the units of the
# grid's CRS (presumably metres -- confirm against the perimeter files).
MIN_DISTANCE: int = 100     # floor applied to cell-to-cell distances before computing proximity
MAX_DISTANCE: int = 2_500   # buffer radius used to look for neighbouring cells

# Execution settings.
N_JOBS: int = 4             # number of joblib workers; 1 runs the chunks sequentially
LEN_GRID_MAX: int = 10_000  # number of source cells handled per chunk

# Input / output directories (relative paths resolve against the current working directory).
path_perimeter: str = '../data/processed/perimeter'
path_proximity: str = '../data/processed/proximity'

def compute_distances(grid, buf_grid):
    """
    Computes the distances between each cell in the grid and its neighboring cells.
    A cell B is a neighbor of cell A when B's geometry intersects A's buffered
    geometry (taken from buf_grid). The distance itself is measured between the
    original, unbuffered geometries of A and B, in the units of the grid's CRS.
    With a non-negative buffer, every cell is also paired with itself (distance 0).
    Parameters:
    grid (GeoDataFrame): The grid cells, with 'GRID_ID' and 'geometry' columns.
                         Other columns are ignored. GRID_ID values must be unique.
    buf_grid (GeoDataFrame): Buffered cells, with 'from_GRID_ID' and 'geometry'
                             columns. Other columns are ignored.
    Returns:
    DataFrame: One row per neighboring pair, with columns 'from_GRID_ID',
               'to_GRID_ID' and 'distance'.
    """
    # Keep only the columns the join needs: any other attribute present in both
    # frames would collide in the join. Keeping GRID_ID as a regular column
    # (instead of the index) also avoids depending on how the installed
    # geopandas version names the right-hand index column in the result.
    cells = grid[['GRID_ID', 'geometry']]
    pairs = gpd.sjoin(buf_grid[['from_GRID_ID', 'geometry']], cells,
                      how='inner', predicate='intersects')
    pairs = (pairs[['from_GRID_ID', 'GRID_ID']]
             .rename(columns={'GRID_ID': 'to_GRID_ID'})
             .reset_index(drop=True))

    # Look up the unbuffered geometry of both ends of every pair. Both series
    # get the same fresh RangeIndex, so the distance is computed row by row.
    geoms = cells.set_index('GRID_ID')['geometry']
    geom_from = geoms.loc[pairs['from_GRID_ID'].to_numpy()].reset_index(drop=True)
    geom_to = geoms.loc[pairs['to_GRID_ID'].to_numpy()].reset_index(drop=True)
    pairs['distance'] = geom_from.distance(geom_to).to_numpy()
    return pairs

def distance_table(grid, *, max_distance=None, n_jobs=None, chunk_size=None):
    """
    Computes a distance table for a given grid.
    Every cell is buffered by max_distance, and the distances to all cells
    intersecting that buffer are computed. The work is split into chunks of
    source cells, which are processed sequentially (n_jobs == 1) or with joblib.
    Parameters:
    grid (GeoDataFrame): A GeoDataFrame containing the grid cells with
                         geometries and a 'GRID_ID' column.
    max_distance (float, optional): Buffer radius in CRS units; defaults to MAX_DISTANCE.
    n_jobs (int, optional): Number of joblib workers; 1 runs sequentially.
                            Defaults to N_JOBS.
    chunk_size (int, optional): Number of source cells per chunk; defaults to LEN_GRID_MAX.
    Returns:
    DataFrame: Columns 'from_GRID_ID', 'to_GRID_ID' and 'distance'. It is empty,
               but has those columns, when the grid has no cells.
    """
    # Resolve defaults at call time so later changes to the module constants are honoured.
    max_distance = MAX_DISTANCE if max_distance is None else max_distance
    n_jobs = N_JOBS if n_jobs is None else n_jobs
    chunk_size = LEN_GRID_MAX if chunk_size is None else chunk_size

    # Only the id and geometry are needed downstream. Dropping the other
    # attributes avoids column clashes in the spatial join and shrinks what
    # is shipped to each joblib worker.
    cells = grid[['GRID_ID', 'geometry']]
    buffered = cells.copy()
    buffered['geometry'] = buffered.buffer(max_distance)
    buffered = buffered.rename(columns={'GRID_ID': 'from_GRID_ID'})

    n_cells = len(buffered)
    bounds = [(start, min(start + chunk_size, n_cells))
              for start in range(0, n_cells, chunk_size)]

    if not bounds:
        # pd.concat([]) would raise, so return an empty but well-formed table instead.
        id_dtype = grid['GRID_ID'].dtype
        return pd.DataFrame({
            'from_GRID_ID': pd.Series(dtype=id_dtype),
            'to_GRID_ID': pd.Series(dtype=id_dtype),
            'distance': pd.Series(dtype='float64'),
        })

    if n_jobs == 1:
        # Collect the parts and concatenate once, rather than re-copying the
        # accumulated frame on every iteration.
        parts = [compute_distances(cells, buffered.iloc[start:end])
                 for start, end in tqdm(bounds)]
    else:
        parts = Parallel(n_jobs=n_jobs, verbose=1)(
            delayed(compute_distances)(cells, buffered.iloc[start:end])
            for start, end in bounds
        )
    return pd.concat(parts, ignore_index=True)

def compute_proximity(distances, min_distance=None):
    """
    Compute the proximity values based on distances.
    Distances below min_distance are raised to min_distance. The proximity is
    then min_distance^2 / distance^2: it equals 1 for any pair closer than
    min_distance and decays with the inverse square of the distance beyond it.
    The 'distance' column is replaced by 'prox' (float32), and the id columns
    are renamed. The input DataFrame is not modified.
    Parameters:
    distances (pandas.DataFrame): A DataFrame containing at least the columns 'distance',
                                  'from_GRID_ID', and 'to_GRID_ID'.
    min_distance (float, optional): Distance floor. Defaults to the module-level
                                    MIN_DISTANCE and must be positive.
    Returns:
    pandas.DataFrame: A new DataFrame with columns 'from_grid_ref', 'to_grid_ref' and
                      'prox', plus any other input columns except 'distance'.
    Raises:
    ValueError: If min_distance is not positive.
    """
    if min_distance is None:
        min_distance = MIN_DISTANCE
    if min_distance <= 0:
        # A zero floor would let zero distances through and produce 0/0 = NaN.
        raise ValueError(f'min_distance must be positive, got {min_distance}')
    floored = distances['distance'].clip(lower=min_distance)
    # Build a new frame instead of assigning into the caller's DataFrame.
    return (
        distances.drop(columns='distance')
        .assign(prox=min_distance**2 / floored**2)
        .rename(columns={'from_GRID_ID': 'from_grid_ref', 'to_GRID_ID': 'to_grid_ref'})
        .astype({'prox': 'float32'})
    )

def _write_fua_proximity(source_file, target_file):
    """
    Builds the proximity table of one FUA grid and writes it to Parquet.
    Parameters:
    source_file (str): Path of the FUA grid GeoPackage.
    target_file (str): Path of the Parquet file to create.
    """
    grid = gpd.read_file(source_file).reset_index(names='GRID_ID')
    # GRID_ID comes from the row index returned by read_file (0..len-1 for a
    # default RangeIndex), so uint16 is wide enough below 2**16 rows.
    if len(grid) < 2**16:
        int_type = pa.uint16()
        grid['GRID_ID'] = grid['GRID_ID'].astype('uint16')
    else:
        int_type = pa.uint32()
        grid['GRID_ID'] = grid['GRID_ID'].astype('uint32')
    distances = compute_proximity(distance_table(grid))
    schema = pa.schema([
        ('from_grid_ref', int_type),
        ('to_grid_ref', int_type),
        ('prox', pa.float32())
    ])
    # main() skips any FUA whose output already exists. Write under a temporary
    # name and rename only once the file is complete, so an interrupted write
    # cannot leave a truncated file that later runs would treat as finished.
    tmp_file = f'{target_file}.tmp'
    distances.to_parquet(tmp_file, engine="pyarrow", compression="snappy",
                         schema=schema)
    os.replace(tmp_file, target_file)

def main():
    """
    Computes and stores the proximity table of every FUA of every listed country.
    Reads each '<path_perimeter>/<country>/<fua>.gpkg' (except 'fuas.gpkg') and
    writes '<path_proximity>/<country>/<fua>.parquet'. FUAs whose output already
    exists are skipped, so an interrupted run can simply be restarted.
    """
    list_countries = ['DEU','FRA','ITA','ESP','GBR','PRT','IRL','NLD']
    for country in list_countries:
        source_dir = os.path.join(path_perimeter, country)
        destination_path = os.path.join(path_proximity, country)
        # splitext strips only the final extension, so a dotted name such as
        # 'a.b.gpkg' maps back to 'a.b' rather than 'a'. Sorting makes the
        # processing order deterministic.
        list_fuas = sorted(
            os.path.splitext(name)[0]
            for name in os.listdir(source_dir)
            if name.endswith('.gpkg') and name != 'fuas.gpkg'
        )
        for fua in list_fuas:
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(destination_path, exist_ok=True)
            target_file = os.path.join(destination_path, f'{fua}.parquet')
            if os.path.exists(target_file):
                continue
            _write_fua_proximity(os.path.join(source_dir, f'{fua}.gpkg'),
                                 target_file)

if __name__ == '__main__':
    main()