import geopandas as gpd
from shapely.geometry import Point
from math import pi
import numpy as np
import pandas as pd
import os

# Root folder of the per-FUA summary GeoPackages (one sub-folder per country code).
dir_summary = '../data/processed/summary'

# Two-letter prefix of a FUA code -> three-letter country sub-folder name.
country_dict = {
    'DE': 'DEU',
    'FR': 'FRA',
    'IT': 'ITA',
    'ES': 'ESP',
    'PT': 'PRT',
    'NL': 'NLD',
    'UK': 'GBR',
    'IE': 'IRL',
}

# Folder where the compiled city indicators (GeoPackage and CSV) are read and written.
path_city_indicators = '../data/processed/city_indicators'

def entropy_index(mix, weight, mix_city=None):
    """
    Calculates the normalized entropy index for a given mix of categories and their associated weights.
    For every unit the binary entropy of its mix is compared with the binary entropy of the
    city-level mix; the weighted average of the differences is divided by the city-level entropy.
    Parameters
    ----------
    mix : pandas Series
        Proportion (between 0 and 1) of the group of interest in each unit (e.g., grid cell).
    weight : pandas Series
        Weight of each unit in `mix` (e.g., population or area).
    mix_city : float or None, optional
        The overall mix proportion for the city. If None, it is calculated as the weighted average of `mix`.
    Returns
    -------
    float
        The normalized entropy index: 0 when every unit has the city-level mix, 1 when every
        unit contains a single group. NaN when the index is undefined, i.e. when the city-level
        mix is not strictly between 0 and 1 (only one group present, or zero total weight).
    Notes
    -----
    - Both log arguments are clipped from below with the same bound to avoid log(0).
    - The entropy index is normalized by the city-level entropy.
    """
    # One lower bound for both (symmetric) terms of the binary entropy.
    eps = 0.000001
    if mix_city is None:
        # City-level mix as the weighted average of the unit-level mixes
        mix_city = (mix * weight).sum() / weight.sum()
    # With a city mix of 0 or 1 (or NaN from a zero total weight) the city entropy is 0 or
    # undefined and the normalization below is a division by zero. Return NaN directly
    # instead of reaching it through log(0) and 0/0 runtime warnings.
    if not 0 < mix_city < 1:
        return np.nan
    entropy_city = -(mix_city*np.log(mix_city) + (1-mix_city)*np.log(1-mix_city))
    entropy_units = -(mix*np.log(np.clip(mix, eps, None)) + (1-mix)*np.log(np.clip(1-mix, eps, None)))
    return ((entropy_city-entropy_units)*weight).sum()/(weight.sum()*entropy_city)

def entropy_index_region(df, col_mix, col_weight, col_group):
    """
    Calculates the normalized entropy index between groups of units (regions) instead of single units.
    Units are first aggregated by `col_group`: each group gets the sum of its weights and the
    weighted average of its mix. The binary entropy of every group is then compared with the
    binary entropy of the overall mix, weighted by the group weights and normalized by the
    overall entropy.
    Parameters:
        df (pd.DataFrame): Input DataFrame containing the data.
        col_mix (str): Name of the column representing the mixing variable (values between 0 and 1).
        col_weight (str): Name of the column representing the weights (e.g., population, area).
        col_group (str): Name of the column representing the group or subregion.
    Returns:
        float: The normalized entropy index for the region: 0 when every group has the overall mix,
        1 when every group contains a single category. NaN when the index is undefined, i.e. when
        the overall mix is not strictly between 0 and 1 (only one category present, or zero total weight).
    """
    # Lower bound applied to the log arguments to avoid log(0)
    eps = 0.000001
    mix_city = (df[col_mix] * df[col_weight]).sum()/df[col_weight].sum()
    # With an overall mix of 0 or 1 (or NaN from a zero total weight) the overall entropy is 0
    # or undefined and the normalization below is a division by zero. Return NaN directly
    # instead of reaching it through log(0) and 0/0 runtime warnings.
    if not 0 < mix_city < 1:
        return np.nan
    entropy_city = -(mix_city*np.log(mix_city) + (1-mix_city)*np.log(1-mix_city))
    # Aggregate only the two columns that are needed, so that unrelated (possibly
    # non-numeric) columns of df are never summed.
    weighted = df[[col_group, col_weight]].copy()
    weighted['weighted_mix'] = df[col_mix] * df[col_weight]
    totals = weighted.groupby(col_group)[[col_weight, 'weighted_mix']].sum()
    group_weight = totals[col_weight]
    group_mix = totals['weighted_mix']/group_weight
    group_entropy = -(group_mix*np.log(np.clip(group_mix, eps, None))
                      + (1-group_mix)*np.log(np.clip(1-group_mix, eps, None)))
    return ((entropy_city-group_entropy)*group_weight).sum()/(group_weight.sum()*entropy_city)

def dissimilarity_index(mix, weight):
    """
    Calculate the dissimilarity index for segregation analysis.
    
    Parameters:
    -----------
    mix : array-like or pandas Series
        Proportion of group 1 in each geographic unit (between 0 and 1)
    weight : array-like or pandas Series
        Total population weight for each geographic unit
    
    Returns:
    --------
    float
        Dissimilarity index (between 0 and 1)
        - 0 indicates complete integration
        - 1 indicates complete segregation
        1.0 is also returned when one of the two groups has no population at all.
    
    Notes:
    ------
    The dissimilarity index measures the percentage of either group that would
    need to move to achieve an even distribution across all geographic units.
    Inputs that are not pandas Series (lists, numpy arrays) are matched by position;
    two Series are matched on their index.
    """

    # The code below relies on Series methods (.loc, .abs), so plain sequences and
    # arrays are wrapped first. A non-Series input borrows the index of the other
    # input so that the boolean mask and the products stay aligned.
    if not isinstance(weight, pd.Series):
        shared_index = mix.index if isinstance(mix, pd.Series) else None
        weight = pd.Series(np.asarray(weight), index=shared_index)
    if not isinstance(mix, pd.Series):
        mix = pd.Series(np.asarray(mix), index=weight.index)

    # Remove any units with zero population to avoid division issues
    valid_mask = weight > 0
    populated_cells_mix = mix.loc[valid_mask].copy()
    populated_cells_weight = weight.loc[valid_mask].copy()
    
    # Calculate group populations in each unit
    group1_pop = populated_cells_mix * populated_cells_weight  # Population of group 1 in each unit
    group2_pop = (1 - populated_cells_mix) * populated_cells_weight  # Population of group 2 in each unit
    
    # Calculate total populations for each group
    total_group1 = group1_pop.sum()
    total_group2 = group2_pop.sum()
    
    # Handle edge cases
    if total_group1 == 0 or total_group2 == 0:
        return 1.0  # Complete segregation if one group is absent
    
    # Share of each group's total population living in each unit
    prop_group1 = group1_pop / total_group1
    prop_group2 = group2_pop / total_group2
    
    # Half the sum of absolute differences between the two distributions
    return 0.5 * (prop_group1 - prop_group2).abs().sum()

# Indicator columns computed per FUA, in the order they are stored in a result row.
_INDICATOR_COLUMNS = ['fua','valid','total_pop','share_migrants','avg_density','frag_index',
                      'center_of_mass_dist','center_of_mass_mig','center_of_mass_dist_norm',
                      'center_of_mass_dist_nonweighted_norm',
                      'entropy','entropy_region', 'dissimilarity']


def _fua_indicators(fua, grid):
    """Compute all indicators of one FUA from its grid and return them as a one-row GeoDataFrame.

    The row geometry is the population-weighted center of mass of the grid.
    """
    # Valid demographic regions (at least one seg=1 and one seg=-1)
    seg = grid['seg'].value_counts()
    valid = (1 in seg.index) and (-1 in seg.index)
    # Total population
    total_pop = grid['pop'].sum()
    # Share of migrants in the total population
    share_migrants = (grid['pop']*grid['NOTEU']).sum()/total_pop
    # Average density (pop/km2, cells are 100m x 100m)
    avg_density = total_pop/(len(grid)*0.01)
    # Fragmentation index: one minus the sum of squared population shares of the fragments
    fragment_shares = grid[['fragment_id','pop']].groupby('fragment_id').sum()/total_pop
    frag_index = 1-(fragment_shares**2).sum().values[0]
    # Segregation indices: entropy per cell, entropy per 'seg' region, dissimilarity
    entropy = entropy_index(grid['mov_avg'], grid['pop'])
    entropy_region = entropy_index_region(grid[['mov_avg', 'pop', 'seg']], 'mov_avg', 'pop', 'seg')
    dissimilarity = dissimilarity_index(grid['NOTEU'], grid['pop'])
    # Center of mass position (population-weighted mean of the cell centroids)
    centroids = grid['geometry'].centroid
    center_of_mass = Point((centroids.x*grid['pop']).sum()/total_pop,
                           (centroids.y*grid['pop']).sum()/total_pop)
    # Distance of every cell to the center of mass
    dist = grid['geometry'].distance(center_of_mass)
    # Population-weighted and unweighted average distance to the center of mass
    center_of_mass_dist_pop = (grid['pop']*dist).sum()/total_pop
    center_of_mass_dist_nonweighted = dist.sum()/len(grid)
    # Average distance of migrants to the center of mass, normalized by average distance of the whole population
    pop_mig = grid['pop']*grid['mov_avg']
    center_of_mass_dist_mig = (pop_mig*dist).sum()/pop_mig.sum()/center_of_mass_dist_pop
    # Average distance to the center of mass, normalized by city size: 2R/3 is the mean
    # distance to the center of a uniform disc with the same area as the grid
    # (len(grid) cells of 10000 m2 each)
    av_dist_circle = (2/3*(len(grid)*10000/pi)**0.5)
    center_of_mass_dist_norm = center_of_mass_dist_pop/av_dist_circle
    center_of_mass_dist_nonweighted_norm = center_of_mass_dist_nonweighted/av_dist_circle
    # Values in the same order as _INDICATOR_COLUMNS
    row = [fua, valid, total_pop, share_migrants, avg_density, frag_index,
           center_of_mass_dist_pop, center_of_mass_dist_mig, center_of_mass_dist_norm,
           center_of_mass_dist_nonweighted_norm,
           entropy, entropy_region, dissimilarity]
    return gpd.GeoDataFrame(data=[row], columns=_INDICATOR_COLUMNS,
                            geometry=[center_of_mass], crs='EPSG:3035')


def _merge_replacing(results, table, key='fua', protected=()):
    """Inner-merge `table` into `results` on `key`, first dropping stale copies of its columns.

    Columns of `table` that already exist in `results` (left over from an earlier run that
    was loaded from disk) are dropped before merging, so the merge refreshes them instead of
    producing `_x`/`_y` duplicates. Columns listed in `protected` are never dropped.
    """
    stale = [col for col in table.columns
             if col != key and col in results.columns and col not in protected]
    if stale:
        results = results.drop(columns=stale)
    return results.merge(table, on=key)


def main():
    """Compute the city indicators of every FUA not yet processed and save the compiled table.

    Existing results are loaded from 'city_indicators.gpkg' when present, FUAs already in
    that table are skipped, new FUAs are read from the per-country summary folders, and the
    table is merged with the purity and city-name CSV files before being written back as
    GeoPackage and CSV.
    """
    indicators_file = os.path.join(path_city_indicators, 'city_indicators.gpkg')
    if os.path.exists(indicators_file):
        results = gpd.read_file(indicators_file, layer='city_indicators')
    else:
        results = gpd.GeoDataFrame(columns=['fua','valid','total_pop','share_migrants','avg_density','frag_index',
                                            'center_of_mass_dist','center_of_mass_mig','center_of_mass_dist_norm',
                                            'entropy','entropy_region','geometry'],
                                geometry='geometry', crs='EPSG:3035')

    list_countries = ['DEU','FRA','ITA','ESP','GBR','PRT','IRL','NLD']

    # FUAs already present in the results; a set gives a constant-time skip check
    processed = set(results['fua'])
    new_rows = []
    for country in list_countries:
        for file_name in os.listdir(os.path.join(dir_summary, country)):
            if not file_name.endswith('.gpkg'):
                continue
            fua = file_name.split('.gpkg')[0].split('_summary')[0]
            if fua in processed:
                continue
            # Load the grid
            grid = gpd.read_file(f'{dir_summary}/{country_dict[fua[:2]]}/{fua}_summary.gpkg', layer='grid')
            new_rows.append(_fua_indicators(fua, grid))
            processed.add(fua)
    # Single concatenation instead of growing the table inside the loop
    if new_rows:
        results = pd.concat([results, *new_rows], ignore_index=True)

    # Computed columns must never be replaced by a same-named column of a merged table
    protected = set(_INDICATOR_COLUMNS) | {'country', 'geometry'}
    purity = pd.read_csv('../data/processed/purity/purity.csv')
    if 'fuaname' in results.columns:
        results = results.drop(columns='fuaname')
    results = _merge_replacing(results, purity, protected=protected)
    results['country'] = results['fua'].apply(lambda x: x[:2])
    results['valid'] = results['valid'].astype(bool)
    results['total_pop'] = results['total_pop'].astype(int)
    fuanames = pd.read_csv(os.path.join(path_city_indicators,'city_names.csv'))
    results = _merge_replacing(results, fuanames, protected=protected)
    results.to_file(indicators_file, layer='city_indicators', driver='GPKG')
    results.drop(columns='geometry').to_csv(os.path.join(path_city_indicators,'city_indicators.csv'), index=False)
    return None
    
# Run the indicator computation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()