# This script builds the study perimeter of each functional urban area (FUA) from the FUA's gridded demographic data.
# It focuses on urban cores: an FUA containing several urban cores is split into several study perimeters.
import pandas as pd
import geopandas as gpd
import os
import numpy as np
from tqdm import tqdm
from shapely.ops import transform
from pyproj import Transformer
from shapely.geometry import box
from joblib import Parallel, delayed

# --- Grid geometry -------------------------------------------------------
CRS = 'EPSG:3035'        # projected CRS of the grid (coordinates in metres)
CELL_SIZE = 100          # side length of one grid cell, in CRS units
EXTRA_DISTANCE = 5       # padding added on every side of a cell square before dissolving cells into regions

# --- Density smoothing ---------------------------------------------------
MIN_DISTANCE = 405       # neighbours closer than this get weight 1; farther ones decay as (MIN_DISTANCE / d) ** 2
MAX_DISTANCE = 2_500     # radius of the buffer used to look up neighbours of a cell
LEN_GRID_MAX = 10_000    # number of origin cells processed per chunk
MAX_WORKERS = 4          # joblib workers; 1 runs the chunks sequentially with a progress bar
FLOAT_PRECISION = 'float32'  # dtype used for distances and weights

# --- Urban-core selection ------------------------------------------------
DENSITY_THRESHOLD = 1_000  # cells whose smoothed density is not above this value are discarded
MIN_POP = 50_000           # urban cores whose total population is not above this value are discarded
KEEP_COLS = False          # when False, only the geometry and GRID_ID of each cell are exported

# --- Input / output locations --------------------------------------------
path_raw_data = '../data/raw/demographics'
path_perimeter_study = '../data/processed/perimeter'
# FUAs without any qualifying urban core are appended (one per line) to this file.
path_list_small_fuas = '../data/processed/perimeter/list_fuas/small_fuas.txt'

def build_convex_hull(cells):
    """
    Returns the bounding box and convex hull of the aligned cell centres in `cells`.

    GRID_ID values are parsed as '<y>N<x>E'. Only cells whose x and y coordinates both end in 50
    (i.e. centres lying on the regular 100 m lattice) are kept; the others are treated as misaligned.
    Parameters:
    cells (pandas.DataFrame): Cell data with a 'GRID_ID' column.
    Returns:
    tuple: (bounds, convex_hull) where
        - bounds is (min_x, min_y, max_x, max_y) of the kept cell centres;
        - convex_hull is the shapely convex hull of those centres.
    """
    grid_ids = cells['GRID_ID']
    centres = cells.assign(
        x=grid_ids.str.extract(r'N(.+)E', expand=False).astype(int),
        y=grid_ids.str.extract(r'(.+)N', expand=False).astype(int),
    )
    # Drop cells whose centre is not on the lattice of coordinates ending in 50.
    on_lattice = centres['x'].mod(100).eq(50) & centres['y'].mod(100).eq(50)
    centres = centres.loc[on_lattice]

    bounds = (centres['x'].min(), centres['y'].min(), centres['x'].max(), centres['y'].max())

    points = gpd.GeoDataFrame(
        centres,
        geometry=gpd.points_from_xy(centres['x'], centres['y']),
        crs=CRS,
    )
    return bounds, points.union_all().convex_hull

def build_grid(bounds, cell_size, convex_hull):
    """
    Lays a regular lattice of points over `bounds` and keeps those strictly inside `convex_hull`.

    The lattice starts at (minx, miny) with a step of `cell_size`; the upper bounds themselves are
    excluded by `np.arange`. Points on the hull boundary are rejected by `within` as well.
    Parameters:
    bounds (tuple): (minx, miny, maxx, maxy) of the area to cover.
    cell_size (float): Spacing between neighbouring points.
    convex_hull (shapely.geometry.Polygon): Polygon used to filter the points.
    Returns:
    geopandas.GeoDataFrame: Point geometries (in CRS) with a 'GRID_ID' column of the form '<y>N<x>E'.
    """
    minx, miny, maxx, maxy = bounds
    col_x, row_y = np.meshgrid(np.arange(minx, maxx, cell_size), np.arange(miny, maxy, cell_size))
    flat_x, flat_y = col_x.ravel(), row_y.ravel()

    lattice = gpd.GeoDataFrame(
        {'x': flat_x, 'y': flat_y},
        geometry=gpd.points_from_xy(flat_x, flat_y),
        crs=CRS,
    )

    inside = lattice.loc[lattice.within(convex_hull)].copy()
    inside['GRID_ID'] = inside['y'].astype(str) + 'N' + inside['x'].astype(str) + 'E'
    return inside.drop(columns=['x', 'y'])

def fix_misaligned_demog(demog, cells):
    """
    Fix misaligned demographic data by redistributing population counts to adjacent grid cells.

    A demographic entry is misaligned when its GRID_ID is not one of the *aligned* cells of `cells`.
    Aligned cells are those whose x and y centre coordinates both end in 50, the same rule
    `build_convex_hull` uses. Testing against the aligned subset (rather than all of `cells`) means
    misaligned ids that are themselves listed in `cells` are still detected.

    Each misaligned entry is assumed to sit on the corner shared by four aligned cells, i.e. offset
    by half a cell on both axes. Its population is split equally between the top-left, top-right,
    bottom-left and bottom-right cells. The misaligned entry itself is then removed, so the total
    population is preserved.
    Parameters:
    demog (pd.DataFrame): Demographic data with 'GRID_ID' ('<y>N<x>E') and 'pop' columns.
    cells (pd.DataFrame): Grid cell data with a 'GRID_ID' column of the same form.
    Returns:
    pd.DataFrame: Demographic data summed per GRID_ID, with misaligned populations redistributed.
    """
    half_cell = 50  # aligned 100 m cell centres have coordinates ending in 50

    def _coords(grid_ids):
        # Parse '<y>N<x>E' identifiers into integer x and y Series (same index as grid_ids).
        x = grid_ids.str.extract(r'N(.+)E', expand=False).astype(int)
        y = grid_ids.str.extract(r'(.+)N', expand=False).astype(int)
        return x, y

    cell_x, cell_y = _coords(cells['GRID_ID'])
    aligned_ids = cells['GRID_ID'][(cell_x % 100 == half_cell) & (cell_y % 100 == half_cell)]

    is_misaligned = ~demog['GRID_ID'].isin(aligned_ids)
    misaligned = demog.loc[is_misaligned]
    x, y = _coords(misaligned['GRID_ID'])
    quarter_pop = misaligned['pop'] / 4

    # (dx, dy) offsets of the top-left, top-right, bottom-left and bottom-right neighbours.
    offsets = ((-half_cell, half_cell), (half_cell, half_cell),
               (-half_cell, -half_cell), (half_cell, -half_cell))
    neighbours = [
        pd.DataFrame({
            'GRID_ID': (y + dy).astype(str) + 'N' + (x + dx).astype(str) + 'E',
            'pop': quarter_pop,
        })
        for dx, dy in offsets
    ]

    # Keep only the aligned rows; the misaligned ones are replaced by their four quarters.
    demog = pd.concat([demog.loc[~is_misaligned]] + neighbours, ignore_index=True)
    return demog.groupby('GRID_ID').sum().reset_index()

def compute_mov_average(grid, buf_grid, min_distances):
    """
    Computes a distance-weighted moving average of population around each origin cell of `buf_grid`.

    Each origin (one buffer polygon per row of `buf_grid`) is paired with every cell of `grid` whose
    geometry intersects that buffer. When the buffer surrounds the origin cell, this includes the
    origin itself at distance 0. Distances are measured between the cells' own geometries taken
    from `grid`, not between the buffers.
    For each pair the weight is 1 when distance <= min_distance and (min_distance / distance)**2
    otherwise. The value reported per origin is sum(weight * pop) / sum(weight) * 100.
    Parameters:
    grid (GeoDataFrame): All grid cells, with 'grid_id', 'geometry' and 'pop' columns.
    buf_grid (GeoDataFrame): Buffer polygons of the origin cells, with 'from_grid_id' and 'geometry' columns.
    min_distances (list or scalar): One or several distances. Any value whose type is not exactly
        `list` is wrapped in a one-element list.
        NOTE(review): a tuple would therefore be treated as a single distance and fail.
    Returns:
    pandas.DataFrame: One row per origin cell with at least one intersecting cell. Columns are
        'grid_id' plus one 'density_<min_distance>' column per distance.
    """
    if type(min_distances) is not list:
        min_distances = [min_distances]
    # Joining each cell with their neighbors, using the buffer grid
    # The right frame is indexed by 'grid_id', and the rename below expects the joined right-index
    # column to keep that name. NOTE(review): this relies on geopandas naming that column after a
    # named right index (lsuffix/rsuffix are ''), so confirm it with the installed geopandas version.
    buf_grid = gpd.sjoin(buf_grid.copy(),
                         grid[['grid_id','geometry']].copy().set_index('grid_id'), 
                         how='inner', predicate='intersects',rsuffix='',lsuffix='')
    buf_grid = buf_grid.rename(columns={'grid_id':'to_grid_id'})
    # Splitting the table for the joined cells into two tables, one containing the position of the origin
    # and the other the position of the destination cell.
    grid_to = grid[['grid_id','geometry']].merge(buf_grid[['from_grid_id','to_grid_id']], 
                                                 left_on='grid_id', 
                                                 right_on='to_grid_id').drop(columns='grid_id')
    grid_from = grid[['grid_id','geometry']].merge(buf_grid[['from_grid_id','to_grid_id']], 
                                                    left_on='grid_id',
                                                    right_on='from_grid_id').drop(columns='grid_id')
    
    # Ordering grid_to to have the same index order as grid_from
    # Both are indexed by the (from, to) pair, so the row-wise distance below compares matching pairs.
    grid_from = grid_from.set_index(['from_grid_id','to_grid_id'])
    grid_to = grid_to.set_index(['from_grid_id','to_grid_id'])
    grid_to = grid_to.reindex(grid_from.index)

    # Computing the distance between each pair of cells
    # reset_index turns the (from_grid_id, to_grid_id) index back into columns of a plain DataFrame.
    distances = grid_from.distance(grid_to['geometry']).reset_index(name='distance')
    distances['distance'] = distances['distance'].astype(FLOAT_PRECISION)

    weight_cols = ['weight_' + str(min_distance) for min_distance in min_distances]

    for col, min_distance in zip(weight_cols,min_distances):
        # Weight function
        # 1 up to min_distance, then (min_distance / distance)**2. The division only applies where
        # distance > min_distance, so there is no division by zero for the self-pair.
        distances[col] = np.ones(len(distances), dtype=FLOAT_PRECISION)
        distances[col] = distances[col].mask(distances['distance'] > min_distance,min_distance**2/distances['distance']**2)
    # Adding population data
    # The population is that of the destination (neighbour) cell.
    distances = distances.merge(grid[['grid_id','pop']], left_on='to_grid_id', right_on='grid_id').drop(columns='grid_id')

    # Spatially averaging the density
    density_cols = ['density_' + str(min_distance) for min_distance in min_distances]
    distances[density_cols] = distances[weight_cols].mul(distances['pop'], axis='index')

    # Per origin: sum of weights and sum of weighted populations.
    density = distances[['from_grid_id'] + weight_cols + density_cols].groupby('from_grid_id').sum().reset_index()
    for min_distance in min_distances:
        # Weighted mean population per cell, scaled by 100.
        # NOTE(review): the *100 factor presumably converts a per-cell (100 m x 100 m) count into a
        # per-km2 density; confirm.
        density['density_' + str(min_distance)] = density['density_' + str(min_distance)]/density['weight_' + str(min_distance)]*100
    density = density.drop(columns=weight_cols)
    density = density.rename(columns={'from_grid_id':'grid_id'})
    return density.copy()

def compute_av_density(grid):
    """
    Compute the distance-weighted average population density of the cells of `grid`.

    A buffer of MAX_DISTANCE is drawn around every cell. All cells of `grid` intersecting that
    buffer contribute to the moving average computed by `compute_mov_average` with MIN_DISTANCE.
    The buffers are processed in chunks of LEN_GRID_MAX rows: sequentially with a progress bar
    when MAX_WORKERS == 1, otherwise in parallel with joblib.

    Note: missing populations are replaced by 0 *in place*, so the caller's `grid` is modified.

    Parameters:
    grid (GeoDataFrame): Grid data with at least the following columns:
        - 'grid_id': Unique identifier for each grid cell.
        - 'geometry': Geometric representation of each grid cell.
        - 'pop': Population count for each grid cell (NaN is treated as 0).

    Returns:
    pd.DataFrame: 'grid_id' and 'density_<MIN_DISTANCE>' columns for the processed cells. An empty
    grid yields an empty frame with those columns; `pd.concat([])` would raise instead.
    """
    grid['pop'] = grid['pop'].fillna(0)
    buffers = grid[['grid_id', 'geometry']].copy()
    buffers['geometry'] = buffers.buffer(MAX_DISTANCE)
    buffers = buffers.rename(columns={'grid_id': 'from_grid_id'})

    size_grid = len(buffers)
    if size_grid == 0:
        return pd.DataFrame(columns=['grid_id', 'density_' + str(MIN_DISTANCE)])

    starts = range(0, size_grid, LEN_GRID_MAX)

    def _chunk(start):
        # Buffers of origin cells [start, start + LEN_GRID_MAX); iloc clips the last chunk at the end.
        return buffers.iloc[start:start + LEN_GRID_MAX].copy()

    if MAX_WORKERS == 1:
        results = [compute_mov_average(grid, _chunk(start), MIN_DISTANCE) for start in tqdm(starts)]
    else:
        results = Parallel(n_jobs=MAX_WORKERS)(
            delayed(compute_mov_average)(grid, _chunk(start), MIN_DISTANCE)
            for start in starts
        )
    # Concatenate once instead of growing a DataFrame chunk by chunk.
    return pd.concat(results, ignore_index=True)

def create_squares_from_centers(points):
    """
    Creates axis-aligned square polygons centred on the given points.

    Each square has side CELL_SIZE + 2 * EXTRA_DISTANCE: the cell itself, padded by EXTRA_DISTANCE
    on every side. Squares of neighbouring cells, including diagonal neighbours, therefore overlap
    and form one connected shape when unioned.
    Parameters:
    points (GeoSeries): Point geometries (cell centres).
    Returns:
    Series of shapely Polygons (a GeoSeries when `points` is a GeoSeries), aligned with `points`.
    """
    half_side = CELL_SIZE / 2 + EXTRA_DISTANCE
    return points.apply(lambda p: box(p.x - half_side, p.y - half_side,
                                      p.x + half_side, p.y + half_side))

def find_high_density_regions(grid):
    """
    Groups dense cells into urban cores and keeps the cores that are populous enough.

    1. Keep cells whose 'density_<MIN_DISTANCE>' value exceeds DENSITY_THRESHOLD.
    2. Replace each kept cell by its padded square and dissolve the squares. Each resulting
       connected polygon is one region.
    3. Sum the population per region, rank regions by population (largest first), and keep only
       regions whose population exceeds MIN_POP. The ranks 0, 1, ... become the new 'index_region'.
    4. Drop the density column and cells with zero population.
    Parameters:
    grid (GeoDataFrame): Grid with 'density_<MIN_DISTANCE>', 'pop' and point 'geometry' columns.
    Returns:
    GeoDataFrame: Cells (as padded squares) of the retained cores, with an 'index_region' column.
    """
    density_col = f'density_{MIN_DISTANCE}'

    dense = grid.loc[grid[density_col] > DENSITY_THRESHOLD].copy()
    dense['geometry'] = create_squares_from_centers(dense['geometry'])

    regions = dense[['geometry']].dissolve().explode()
    regions.index = range(len(regions))
    dense = dense.sjoin(regions, rsuffix='region')

    ranking = (
        dense[['pop', 'index_region']]
        .groupby(by='index_region')
        .sum()
        .sort_values(by='pop', ascending=False)
        .reset_index()
    )
    ranking['index_region_new'] = ranking.index.copy()
    large = ranking.loc[ranking['pop'] > MIN_POP, ['index_region', 'index_region_new']]

    dense = (
        dense.merge(large, on='index_region')
        .drop(columns=['index_region', density_col])
        .rename(columns={'index_region_new': 'index_region'})
    )
    return dense.loc[dense['pop'] > 0].copy()

def simplify_axes(grid):
    """
    Replaces the long string 'GRID_ID' column with a compact integer 'grid_id' column.

    'grid_id' is taken from the frame's index. Its dtype is the smallest of uint16 / uint32 that can
    hold the actual index values. The choice depends on the values, not on the row count, so a
    non-default index cannot overflow. Negative or very large ids keep their original integer
    dtype. The input frame is not modified.
    Parameters:
    grid (pd.DataFrame): Grid data with a 'GRID_ID' column and an integer index.
    Returns:
    tuple: A tuple containing:
        - pd.DataFrame: A copy of `grid` with 'grid_id' added and 'GRID_ID' removed.
        - pd.DataFrame: The mapping between 'grid_id' and 'GRID_ID'.
    """
    grid = grid.copy()
    grid['grid_id'] = grid.index.copy()
    min_id = grid['grid_id'].min() if len(grid) else 0
    max_id = grid['grid_id'].max() if len(grid) else 0
    if min_id >= 0:
        if max_id < 2**16:
            grid['grid_id'] = grid['grid_id'].astype(np.uint16)
        elif max_id < 2**32:
            grid['grid_id'] = grid['grid_id'].astype(np.uint32)
    grid_axes = grid[['grid_id', 'GRID_ID']].copy()
    grid = grid.drop(columns='GRID_ID')
    return grid, grid_axes

def extract_perimeter_study(cells, demog, path_country, fua):
    """
    Builds and saves the study perimeters (one per urban core) of one FUA.

    Steps:
    - build a CELL_SIZE grid over the convex hull of the aligned cells;
    - attach the misalignment-corrected population;
    - compute the smoothed density and keep the high-density urban cores;
    - for each core, export every grid cell lying inside the core's convex hull.

    For each core `r`, three files are written to `path_country`:
    'indices_<fua>_<r>.csv' (row number -> GRID_ID), '<fua>_<r>.gpkg' and '<fua>_<r>.poly'.
    If no core qualifies, `fua` is appended to `path_list_small_fuas` and nothing else is written.
    Parameters:
    cells (DataFrame): Cell data with a 'GRID_ID' column.
    demog (DataFrame): Demographic information with 'GRID_ID' and 'pop' columns.
    path_country (str): Directory where output files are saved (created if missing).
    fua (str): Functional Urban Area identifier.
    Returns:
    None
    """
    bounds, convex_hull = build_convex_hull(cells)
    grid = build_grid(bounds, CELL_SIZE, convex_hull)
    # Summing population across all groups.
    demog = demog[['GRID_ID', 'pop']].groupby('GRID_ID').sum().reset_index()
    demog = fix_misaligned_demog(demog, cells)

    # Adding demographic information to the grid (cells without data get NaN population).
    grid = grid.merge(demog, on='GRID_ID', how='left')

    grid, grid_axes = simplify_axes(grid)
    density = compute_av_density(grid)
    grid_buffer = grid.merge(density, on='grid_id')
    grid_buffer = find_high_density_regions(grid_buffer)

    if grid_buffer.empty:
        print(f'Empty grid for {fua}')
        list_dir = os.path.dirname(path_list_small_fuas)
        if list_dir:
            os.makedirs(list_dir, exist_ok=True)
        with open(path_list_small_fuas, 'a') as f:
            f.write(f'{fua}\n')
        return

    os.makedirs(path_country, exist_ok=True)
    # Saving all grid cells belonging to each urban core in a separate file.
    for region in grid_buffer['index_region'].unique():
        core_cells = grid_buffer.loc[grid_buffer['index_region'] == region]
        core_hull = core_cells.union_all().convex_hull
        grid_temp = grid.loc[grid.within(core_hull)].reset_index(drop=True)
        grid_temp = grid_temp.merge(density, on='grid_id')
        if not KEEP_COLS:
            grid_temp = grid_temp[['grid_id', 'geometry']].copy()
        grid_temp = grid_temp.merge(grid_axes, on='grid_id').drop(columns='grid_id')
        # Row numbers in the CSV match the row order of the .gpkg written below.
        grid_temp = grid_temp.reset_index(drop=True)
        stem = f'{fua}_{region}'
        grid_temp['GRID_ID'].to_csv(os.path.join(path_country, f'indices_{stem}.csv'), index=True)
        grid_temp.drop(columns='GRID_ID').to_file(os.path.join(path_country, f'{stem}.gpkg'))
        geodataframe_to_poly(grid_temp, os.path.join(path_country, f'{stem}.poly'), stem)
            

def retrieve_fuas_done():
    """
    Lists the FUA identifiers that already have output under `path_perimeter_study`.

    Every `.gpkg` file found recursively contributes the part of its filename before the first
    underscore. An FUA with several urban cores therefore appears once per core file.
    Returns:
        list: FUA identifiers, in `os.walk` traversal order (duplicates included).
    """
    return [
        filename.split('_')[0]
        for _root, _subdirs, filenames in os.walk(path_perimeter_study)
        for filename in filenames
        if filename.endswith('.gpkg')
    ]

def geodataframe_to_poly(gdf, output_file, region_name):
    """
    Writes the perimeter enclosing all geometries of `gdf` to a .poly polygon file.

    The perimeter is the convex hull of the union of all geometries, buffered by CELL_SIZE / 2 and
    reprojected from CRS to EPSG:4326 in (lon, lat) order. The file contains:
    the region name, a section named 'perimeter', one line per exterior vertex (the repeated closing
    vertex is omitted), then two 'END' lines.

    Parameters:
    -----------
    gdf : GeoDataFrame
        Geometries (in CRS) whose enclosing perimeter is exported.
    output_file : str
        Path of the .poly file to write (overwritten if it exists).
    region_name : str
        Name written on the first line of the file.
    """
    hull = gdf.geometry.union_all().convex_hull.buffer(CELL_SIZE / 2)
    to_wgs84 = Transformer.from_crs(CRS, "EPSG:4326", always_xy=True)
    hull = transform(to_wgs84.transform, hull)

    with open(output_file, 'w') as poly_file:
        poly_file.write(f"{region_name}\n")
        poly_file.write("perimeter\n")
        vertices = hull.exterior.coords[:-1]
        poly_file.writelines(f"    {lon:.6f}    {lat:.6f}\n" for lon, lat in vertices)
        poly_file.writelines(("END\n", "END\n"))

def main():
    """
    Builds the study perimeters of the GBR NUTS regions whose id contains 'UKN'.

    Reads the population grid and the grid-to-NUTS link table from `path_raw_data`, then:
    - drops cells without a NUTS id, and drops region 'UKN05';
    - merges 'UKN02' into 'UKN01';
    - calls `extract_perimeter_study` once per remaining region, writing under
      `path_perimeter_study/<country>`.
    """
    country = 'GBR'
    path_demog = os.path.join(path_raw_data, f'{country}_L4.csv')
    path_grid_link = os.path.join(path_raw_data, f'grid_link_{country}.csv')
    demog = pd.read_csv(path_demog, dtype={'pop': FLOAT_PRECISION})
    cells = pd.read_csv(path_grid_link, usecols=['GRID_ID', 'nuts_id'], dtype={'nuts_id': 'str'})
    cells = cells.loc[cells['nuts_id'].notna()].copy()
    # Filtering out large, sparsely populated NUTS regions.
    cells = cells.loc[cells['nuts_id'] != 'UKN05'].copy()
    # Combining Belfast with Outer Belfast.
    cells['nuts_id'] = cells['nuts_id'].mask(cells['nuts_id'] == 'UKN02', 'UKN01')
    path_country = os.path.join(path_perimeter_study, country)
    list_nuts = cells.loc[cells['nuts_id'].str.contains('UKN'), 'nuts_id'].unique()

    # Each NUTS region is handled as one "FUA" by extract_perimeter_study.
    for fua in list_nuts:
        print(f'Processing {fua}')
        cells_fua = cells.loc[cells['nuts_id'] == fua].copy()
        demog_fua = demog.loc[demog['GRID_ID'].isin(cells_fua['GRID_ID'])].copy()
        extract_perimeter_study(cells_fua, demog_fua, path_country, fua)


if __name__ == '__main__':
    main()