# This script generates synthetic urban fragments within a given study perimeter.
import gstools as gs
import numpy as np
import geopandas as gpd
from shapely.geometry import box
import os
from joblib import Parallel, delayed
import pandas as pd
from tqdm import tqdm
from scipy.stats import wasserstein_distance

# Side length of a grid cell, in CRS units; seed points are jittered by up to
# +/- CELL_SIZE/2 around their cell (presumably metres in a projected CRS — confirm).
CELL_SIZE = 100
# Per-country input directory of grid cells, one <fua>.gpkg per FUA.
dir_perimeter = '../data/processed/perimeter'
# Per-country output directory (<fua>_synthetic.parquet, <fua>_optimization.csv).
dir_synthetic = '../data/processed/synthetic_partition'
# Per-country population tables, <fua>.parquet with 'grid_ref' and 'pop' columns.
dir_pop = '../data/processed/demographics'
# Per-country observed fragments, <fua>_fragments.csv with a 'fragment_id' column.
dir_frag = '../data/processed/urban_fragments'
# Number of joblib workers used by the parallel helpers.
N_WORKERS = 6
# Number of random (k, length_scale) pairs, one field each, explored during optimization.
N_FIELDS_OPTIMIZATION = 500
# Number of fields generated with the optimal parameters.
N_FIELDS_SYNTHETIC = 500
# Number of lowest-distance iterations kept in the output; <= 0 keeps all of them.
N_BEST_IT = 200
# When True, main() only processes FUA 'DE027_0' and ignores existing outputs.
DEBUG = False
# Variance of the Gaussian covariance model of the random field.
FIELD_VARIANCE = 1
# Sampling range of the power-transformation exponent k.
K_MIN = 1
K_MAX = 15
# Half-width of the smoothing neighbourhood along k: 10x the mean spacing of
# N_FIELDS_OPTIMIZATION uniform samples over [K_MIN, K_MAX].
RADIUS_K = (K_MAX - K_MIN) / N_FIELDS_OPTIMIZATION * 10
# Sampling range of the Gaussian-model length scale (CRS units).
LEN_MIN = 500
LEN_MAX = 1500
# Half-width of the smoothing neighbourhood along the length scale (same rule as RADIUS_K).
RADIUS_LEN = (LEN_MAX - LEN_MIN) /N_FIELDS_OPTIMIZATION * 10

def data_preprocessing(country,fua):
    """
    Load the populated grid cells and the observed fragment-size distribution of one FUA.
    Parameters:
    country (str): Country code; name of the sub-directory under each data directory (e.g. 'DEU').
    fua (str): Functional urban area identifier; base name of the per-FUA input files.
    Returns:
    tuple: (cells_pop, size_dist) where
        - cells_pop (GeoDataFrame): perimeter cells inner-joined with their 'pop' value on
          'grid_ref' (the perimeter file's row index), restricted to cells with pop > 0.
        - size_dist (Series): number of rows per 'fragment_id' in the fragments CSV, in
          decreasing order with a fresh 0..n-1 index (i.e. the rank-size distribution).
    """
    perimeter_path = os.path.join(dir_perimeter, country, f'{fua}.gpkg')
    population_path = os.path.join(dir_pop, country, f'{fua}.parquet')
    fragments_path = os.path.join(dir_frag, country, f'{fua}_fragments.csv')

    cells = gpd.read_file(perimeter_path)
    # The row index of the perimeter file is the key shared with the population table.
    cells['grid_ref'] = cells.index.copy()
    population = pd.read_parquet(population_path, columns=['grid_ref', 'pop'])

    joined = cells.merge(population, on='grid_ref', how='inner')
    populated = joined['pop'] > 0
    cells_pop = joined[populated].copy()

    fragments = pd.read_csv(fragments_path, usecols=['fragment_id'])
    counts = fragments.value_counts()
    size_dist = counts.reset_index(drop=True)

    return cells_pop, size_dist

def field_calculator_worker(grid, list_k, list_length):
    """
    Generate one normalized probability field per (k, length scale) pair on the grid.
    Each field is a Gaussian random field evaluated at the grid centroids. It is shifted
    so that its minimum is exactly 0, raised to the power k, and normalized to sum to 1.
    Parameters:
    grid (GeoDataFrame): Grid cells; fields are evaluated at their centroids.
    list_k (array-like of float): Power-transformation exponent of each field.
    list_length (array-like of float): Gaussian-model length scale of each field
        (paired element-wise with list_k).
    Returns:
    numpy.ndarray: Array of shape (len(list_k), len(grid)); row i is a probability vector
                   over the grid cells for the i-th (k, length scale) pair. A field that is
                   constant after the shift (zero total mass) yields a uniform row.
    """
    n = len(list_k)
    probabilities = np.zeros((n, len(grid)))
    if n == 0:
        # Nothing to compute (e.g. a worker that received an empty chunk).
        return probabilities

    # The centroid coordinates are the same for every field: compute them once.
    centroids = grid.centroid
    positions = [centroids.x, centroids.y]

    for i, (k, length) in enumerate(zip(list_k, list_length)):
        model = gs.Gaussian(dim=2, var=FIELD_VARIANCE, len_scale=length)
        srf = gs.SRF(model)  # Gaussian random field with the model above
        field = srf(positions)  # Evaluate the field on the grid centroids
        # Anchor the minimum at exactly 0 whatever its sign. Adding abs(min) did this
        # only when min < 0; with a positive minimum it shifted the field further up.
        field = field - np.min(field)
        field = np.power(field, k)  # Power transform sharpens the contrast for k > 1
        total = np.sum(field)
        if total > 0:
            probabilities[i] = field / total
        else:
            # Constant field (e.g. a single-cell grid): dividing by 0 would produce NaN
            # probabilities, which np.random.choice rejects. Use a uniform field instead.
            probabilities[i] = 1.0 / len(grid)

    return probabilities

def field_calculator(grid, list_k=None, k=None, list_length_scale=None, length_scale=None, n_fields=N_FIELDS_SYNTHETIC):
    """
    Build probability fields on a grid in parallel, one per (k, length scale) pair.
    Each parameter can be given either as a per-field sequence or as a single value.
    A single value is broadcast to n_fields entries.
    Parameters:
    grid (GeoDataFrame): Grid cells on which the fields are evaluated.
    list_k (list or np.ndarray, optional): Exponent k of each field.
    k (float, optional): Exponent shared by all n_fields fields.
    list_length_scale (list or np.ndarray, optional): Length scale of each field.
    length_scale (float, optional): Length scale shared by all n_fields fields.
    n_fields (int, optional): Number of fields to build; it also sets the chunking across
        workers. Default is N_FIELDS_SYNTHETIC.
    Returns:
    np.ndarray: Stacked output of field_calculator_worker, one row per field, in field order.
    Raises:
    ValueError: If neither or both of list_k and k are provided.
    ValueError: If neither or both of list_length_scale and length_scale are provided.
    """
    # Exactly one of list_k / k must be set.
    if k is None and list_k is None:
        raise ValueError("Either list_k or k must be provided")
    if k is not None and list_k is not None:
        raise ValueError("Only one of list_k or k must be provided")
    k_per_field = np.asarray(list_k) if k is None else np.full(n_fields, k)

    # Same rule for the length scale.
    if length_scale is None and list_length_scale is None:
        raise ValueError("Either list_length_scale or length_scale must be provided")
    if length_scale is not None and list_length_scale is not None:
        raise ValueError("Only one of list_length_scale or length_scale must be provided")
    length_per_field = (
        np.asarray(list_length_scale) if length_scale is None else np.full(n_fields, length_scale)
    )

    # Split the fields into N_WORKERS contiguous chunks; the first (n_fields % N_WORKERS)
    # chunks take one extra field.
    chunk_size, leftover = divmod(n_fields, N_WORKERS)
    bounds = []
    lower = 0
    for worker in range(N_WORKERS):
        upper = lower + chunk_size + (1 if worker < leftover else 0)
        bounds.append((lower, upper))
        lower = upper

    per_worker = Parallel(n_jobs=N_WORKERS)(
        delayed(field_calculator_worker)(grid, k_per_field[a:b], length_per_field[a:b])
        for a, b in bounds
    )
    return np.vstack(per_worker)


def sampling_worker(n_frag, fields, start_idx):
    """
    Draw n_frag distinct cell positions from each probability field of a chunk.
    Parameters:
    n_frag (int): Number of distinct cells to sample from each field (sampling without replacement).
    fields (numpy.ndarray): 2D array of shape (n_fields_chunk, n_cells); each row is a
        probability vector over the cells.
    start_idx (int): Global index of the chunk's first field; it is added to the local row
        index to form the iteration id.
    Returns:
    numpy.ndarray: Integer array of shape (n_fields_chunk * n_frag, 2). Column 0 is the
                   iteration id and column 1 the sampled cell position. Rows are grouped
                   per field in order. The result is an empty (0, 2) array when the chunk
                   has no fields.
    """
    n_cells = fields.shape[1]
    blocks = []
    for offset, probs in enumerate(fields):
        sampled = np.random.choice(n_cells, size=n_frag, replace=False, p=probs)
        iteration_ids = np.full(n_frag, start_idx + offset)
        blocks.append(np.column_stack((iteration_ids, sampled)))
    if not blocks:
        # np.vstack([]) raises ValueError; this happens when sampling() hands a worker
        # an empty chunk (fewer fields than workers).
        return np.empty((0, 2), dtype=int)
    return np.vstack(blocks)

def sampling(n_frag, fields):
    """
    Sample seed cells from every probability field, splitting the work across workers.
    Parameters:
    n_frag (int): Number of distinct cells to sample from each field.
    fields (numpy.ndarray): 2D array of shape (n_fields, n_cells); each row is a
        probability vector over the cells.
    Returns:
    numpy.ndarray: Integer array of shape (n_fields * n_frag, 2). Column 0 is the iteration
                   (field) index and column 1 the sampled cell position. Rows are grouped by
                   field in increasing field order.
    """
    n_fields = fields.shape[0]
    if n_fields == 0:
        return np.empty((0, 2), dtype=int)

    # Never create more chunks than fields. An empty chunk would reach sampling_worker
    # with zero rows and, in the original worker, fail on np.vstack([]).
    num_workers = min(N_WORKERS, n_fields)
    n_per_worker, remainder = divmod(n_fields, num_workers)
    start_indices = [i * n_per_worker + min(i, remainder) for i in range(num_workers)]

    results = Parallel(n_jobs=N_WORKERS)(
        delayed(sampling_worker)(
            n_frag,
            fields[start:start + n_per_worker + (1 if i < remainder else 0)],
            start,
        ) for i, start in enumerate(start_indices)
    )
    # Parallel returns results in submission order, so rows stay grouped by field.
    return np.vstack(results)

def find_nearest_worker(cells, points):
    """
    Assign every cell to its nearest seed point, separately for each iteration.
    Parameters:
    cells (GeoDataFrame): Grid cells; must contain a 'grid_ref' column.
    points (GeoDataFrame): Seed points with 'iteration_id' and 'fragment_id' columns.
    Returns:
    DataFrame: One row per (cell, iteration). It includes the columns 'grid_ref',
               'fragment_id' (of the nearest seed) and 'iteration_id', plus any other
               non-geometry columns of `cells`. When `points` is empty, the result is an
               empty frame with just 'grid_ref', 'fragment_id' and 'iteration_id'.
    """
    parts = []
    # groupby(sort=False) visits iterations in order of first appearance, like
    # iterating over drop_duplicates() of the column.
    for it, points_it in points.groupby('iteration_id', sort=False):
        nearest = cells.sjoin_nearest(points_it)
        # sjoin_nearest returns every equidistant nearest point. Keep one per cell so
        # the later pivot on (grid_ref, iteration_id) does not meet duplicate entries.
        nearest = nearest.drop_duplicates(subset='grid_ref')
        nearest = nearest.drop(columns=['index_right', 'geometry'])
        nearest['iteration_id'] = it
        parts.append(nearest)

    if not parts:
        return pd.DataFrame(columns=['grid_ref', 'fragment_id', 'iteration_id'])
    # A single concat avoids re-copying the growing frame on every iteration.
    return pd.concat(parts, ignore_index=True)


def find_nearest(grid, points):
    """
    Compute the nearest-seed partition of the grid for every iteration, in parallel.
    The iterations are split into N_WORKERS contiguous ranges of iteration ids.
    Each range goes to find_nearest_worker, and the long-format results are pivoted
    to wide format.
    Args:
        grid (GeoDataFrame): Grid cells with a 'grid_ref' column.
        points (GeoDataFrame): Seed points with 'iteration_id' (expected to be 0..n-1)
            and 'fragment_id' columns.
    Returns:
        pd.DataFrame: A 'grid_ref' column plus one column per iteration_id holding the
        fragment_id assigned to each cell.
    """
    n_iterations = points['iteration_id'].nunique()
    base, extra = divmod(n_iterations, N_WORKERS)

    # Contiguous id ranges; the first `extra` ranges hold one more iteration.
    chunks = []
    lower = 0
    for worker in range(N_WORKERS):
        upper = lower + base + (1 if worker < extra else 0)
        chunks.append(points[points['iteration_id'].isin(range(lower, upper))])
        lower = upper

    partial_results = Parallel(n_jobs=N_WORKERS)(
        delayed(find_nearest_worker)(grid, chunk) for chunk in chunks
    )
    long_format = pd.concat(partial_results, ignore_index=True)

    # Wide format: rows are cells, columns are iterations, values are fragment ids.
    wide = long_format.pivot(index='grid_ref', columns='iteration_id', values='fragment_id')
    wide = wide.reset_index()
    wide.columns.name = None
    return wide

def distance_based_smoothing_2d(data, k_values, length_scale_values, radius_k=None, radius_len=None):
    """
    Smooth per-sample scores over the 2D (k, length_scale) parameter space.
    Each score is replaced by the median of all scores whose parameters lie within the
    rectangular neighbourhood |dk| <= radius_k and |dlength| <= radius_len around it.
    Parameters:
    data (array-like): One score per sample (1D).
    k_values (array-like): The k parameter of each sample (same length as data).
    length_scale_values (array-like): The length-scale parameter of each sample.
    radius_k (float, optional): Neighbourhood half-width along k. Defaults to RADIUS_K.
    radius_len (float, optional): Neighbourhood half-width along the length scale.
        Defaults to RADIUS_LEN.
    Returns:
    np.ndarray: Float array of smoothed scores with the same shape as data.
    """
    # Resolve the defaults at call time so callers with other parameter ranges can pass
    # matching radii.
    if radius_k is None:
        radius_k = RADIUS_K
    if radius_len is None:
        radius_len = RADIUS_LEN

    data = np.asarray(data)
    # Accept plain lists too: `list - float` would raise TypeError below.
    k_arr = np.asarray(k_values, dtype=float)
    len_arr = np.asarray(length_scale_values, dtype=float)
    smoothed_data = np.zeros_like(data, dtype=float)

    for i, (k, length) in enumerate(zip(k_arr, len_arr)):
        mask = (np.abs(k_arr - k) <= radius_k) & (np.abs(len_arr - length) <= radius_len)
        # A sample always falls in its own neighbourhood unless its parameters are NaN.
        if not np.any(mask):
            smoothed_data[i] = data[i]
            continue
        # The median is robust to occasional outlier scores.
        smoothed_data[i] = np.median(data[mask])

    return smoothed_data

def generate_synthetic(grid, n_frag, k=None, list_k=None, length_scale=None, list_length_scale=None, n_fields=N_FIELDS_SYNTHETIC, debug=False):
    """
    Generate n_fields synthetic fragmentations of the grid.
    For each field, n_frag distinct cells are drawn as fragment seeds with the field's
    probabilities. Each seed is jittered uniformly within +/- CELL_SIZE/2 of its cell,
    and every cell is then assigned to its nearest seed.
    Parameters:
    grid (GeoDataFrame): Grid cells (point geometries, since `.x`/`.y` are read) with a 'grid_ref' column.
    n_frag (int): Number of fragments (seeds) per field.
    k (float, optional): Single exponent used for all fields.
    list_k (array-like, optional): Exponent of each field.
    length_scale (float, optional): Single length scale used for all fields.
    list_length_scale (array-like, optional): Length scale of each field.
    n_fields (int, optional): Number of fields to generate. Default is N_FIELDS_SYNTHETIC.
    debug (bool, optional): If True, also return the intermediate results. Default is False.
    Returns:
    pd.DataFrame: Output of find_nearest, i.e. 'grid_ref' plus one column per iteration
    holding each cell's fragment_id. If debug is True, returns the tuple
    (fields, random_points, synthetic_frag) instead.
    Raises:
    ValueError: If both k and list_k are None or both are provided.
    ValueError: If both length_scale and list_length_scale are None or both are provided.
    """
    if k is None and list_k is None:
        raise ValueError("Either k or list_k must be provided")
    if k is not None and list_k is not None:
        raise ValueError("Only one of k or list_k must be provided")
    if length_scale is None and list_length_scale is None:
        raise ValueError("Either length_scale or list_length_scale must be provided")
    if length_scale is not None and list_length_scale is not None:
        raise ValueError("Only one of length_scale or list_length_scale must be provided")

    # Exactly one member of each pair is set here, so forwarding all four keywords is
    # the same as picking the matching combination.
    fields = field_calculator(
        grid,
        list_k=list_k,
        k=k,
        list_length_scale=list_length_scale,
        length_scale=length_scale,
        n_fields=n_fields,
    )

    sampled_indices = sampling(n_frag, fields)
    seed_geoms = grid.iloc[sampled_indices[:, 1]]['geometry']

    # Jitter each seed inside its cell (x offsets are drawn before y offsets).
    half_cell = CELL_SIZE / 2
    xs = seed_geoms.x + np.random.uniform(-half_cell, half_cell, size=len(seed_geoms))
    ys = seed_geoms.y + np.random.uniform(-half_cell, half_cell, size=len(seed_geoms))
    random_points = gpd.GeoDataFrame(geometry=gpd.points_from_xy(xs, ys, crs=grid.crs))

    random_points['iteration_id'] = sampled_indices[:, 0]
    # Rows come in per-iteration blocks of n_frag, so this numbers the fragments
    # 0..n_frag-1 within each iteration.
    random_points['fragment_id'] = np.arange(len(random_points)) - random_points['iteration_id'] * n_frag

    synthetic_frag = find_nearest(grid, random_points)
    return (fields, random_points, synthetic_frag) if debug else synthetic_frag


def find_optimal_parameters(actual_sizes, grid, k_min=K_MIN, k_max=K_MAX, len_min=LEN_MIN, len_max=LEN_MAX, n_fields=N_FIELDS_OPTIMIZATION):
    """
    Search (k, length_scale) by random sampling for the pair whose synthetic fragment
    sizes best match the observed ones.
    n_fields parameter pairs are drawn uniformly; a range with min == max yields a
    constant. One synthetic fragmentation is generated per pair and scored with the
    Wasserstein distance between the observed and synthetic size samples. The scores are
    smoothed over parameter space, and the pair with the lowest smoothed score wins.
    The smoothing radii are the module-level RADIUS_K / RADIUS_LEN, which are derived
    from the default ranges and N_FIELDS_OPTIMIZATION.
    Parameters:
    actual_sizes (array-like): Observed fragment sizes; its length sets the number of
        fragments generated per field.
    grid (GeoDataFrame): Grid on which to generate the synthetic fragmentations.
    k_min, k_max (float, optional): Sampling range of k. Defaults are K_MIN and K_MAX.
    len_min, len_max (float, optional): Sampling range of the length scale.
        Defaults are LEN_MIN and LEN_MAX.
    n_fields (int, optional): Number of parameter pairs tried. Default is N_FIELDS_OPTIMIZATION.
    Returns:
    tuple: (optimal_k, optimal_length_scale, results), where results has one row per
    pair with the columns 'k', 'length_scale', 'distance' and 'smoothed_distance'.
    """
    n_frag = len(actual_sizes)

    def _draw(low, high):
        # A degenerate range gives a constant vector and consumes no random numbers.
        if low == high:
            return np.full(n_fields, low)
        return np.random.uniform(low, high, n_fields)

    k_values = _draw(k_min, k_max)
    length_scale_values = _draw(len_min, len_max)

    synthetic_data = generate_synthetic(
        grid, n_frag, list_k=k_values, list_length_scale=length_scale_values, n_fields=n_fields
    )

    # Iteration columns are labelled 0..n_fields-1 by find_nearest's pivot.
    wasserstein_distances = [
        wasserstein_distance(actual_sizes, synthetic_data.loc[:, it].value_counts().values)
        for it in range(n_fields)
    ]
    smoothed_distances = distance_based_smoothing_2d(wasserstein_distances, k_values, length_scale_values)

    best = np.argmin(smoothed_distances)
    results = pd.DataFrame({
        'k': k_values,
        'length_scale': length_scale_values,
        'distance': wasserstein_distances,
        'smoothed_distance': smoothed_distances,
    })
    return k_values[best], length_scale_values[best], results
    

def main():
    """
    Run the synthetic-partition pipeline for every FUA of every listed country.
    For each FUA with usable inputs, the pipeline:
    - fits (k, length_scale) with find_optimal_parameters and writes the search table to
      `<fua>_optimization.csv`;
    - generates N_FIELDS_SYNTHETIC fragmentations with the optimal parameters;
    - keeps the N_BEST_IT iterations closest to the observed rank-size distribution
      (Wasserstein distance; all iterations if N_BEST_IT <= 0) and writes them to
      `<fua>_synthetic.parquet`.
    FUAs whose synthetic output already exists are skipped unless DEBUG is set. With
    DEBUG, every FUA name is replaced by 'DE027_0' and both loops stop after the first
    FUA that is fully processed.
    """
    list_countries = ['DEU', 'FRA', 'ITA', 'ESP', 'GBR', 'PRT', 'IRL', 'NLD']
    for country in list_countries:
        perimeter_dir = os.path.join(dir_perimeter, country)
        if not os.path.isdir(perimeter_dir):
            print(f'No perimeter directory found for {country}')
            continue
        # FUA identifiers are the .gpkg file names without their extension
        # (splitext keeps any dots inside the name itself).
        list_fuas = [os.path.splitext(name)[0] for name in os.listdir(perimeter_dir) if name.endswith('.gpkg')]
        output_dir = os.path.join(dir_synthetic, country)
        os.makedirs(output_dir, exist_ok=True)

        for fua in tqdm(list_fuas):
            if DEBUG:
                fua = 'DE027_0'
            output_path = os.path.join(output_dir, f'{fua}_synthetic.parquet')
            # Resume support: skip FUAs that are already done (always recompute in DEBUG).
            if not DEBUG and os.path.exists(output_path):
                continue

            try:
                cells_pop, size_dist = data_preprocessing(country, fua)
            except Exception as err:
                # Best-effort batch run: report and skip FUAs with missing or unreadable
                # inputs. Unlike a bare except, this does not swallow KeyboardInterrupt.
                print(f'Skipping {country}/{fua}: could not load inputs ({type(err).__name__}: {err})')
                continue
            n_frag = len(size_dist)
            if n_frag == 0:
                print(f'No fragments found for {country}/{fua}')
                continue

            k_optimal, optimal_length_scale, optimization_results = find_optimal_parameters(
                size_dist, cells_pop, n_fields=N_FIELDS_OPTIMIZATION)
            optimization_results.to_csv(os.path.join(output_dir, f'{fua}_optimization.csv'), index=False)

            synthetic_frag = generate_synthetic(
                cells_pop, n_frag, k=k_optimal, length_scale=optimal_length_scale, n_fields=N_FIELDS_SYNTHETIC)
            synthetic_frag = synthetic_frag.set_index('grid_ref')

            # Score each iteration against the observed rank-size distribution; build the
            # table in one go instead of concatenating row by row.
            actual_sizes = size_dist.values
            iteration_ids = list(synthetic_frag.columns)
            distances = np.array(
                [wasserstein_distance(actual_sizes, synthetic_frag.loc[:, col].value_counts().values)
                 for col in iteration_ids],
                dtype=float,
            )
            distance_df = pd.DataFrame({'iteration_id': iteration_ids, 'distance': distances})

            if N_BEST_IT > 0:
                best_its = distance_df.nsmallest(N_BEST_IT, 'distance')['iteration_id']
                synthetic_frag = synthetic_frag.loc[:, best_its].copy()
            # pandas' parquet writer requires string column names.
            synthetic_frag.columns = synthetic_frag.columns.map(str)
            synthetic_frag.to_parquet(output_path, engine='pyarrow', compression='snappy')
            if DEBUG:
                break
        if DEBUG:
            break

    return

if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not when imported.
    main()