# Read perimeter data, build grid
# read demographic data
# read urban fragmentation
# read 2 iterations of synthetic partitions
# merge data
# read urban frontiers (store in the same gpkg in a different layer)
import os
import pandas as pd
import geopandas as gpd
from shapely.geometry import box
from tqdm import tqdm

# Side length of each square grid cell built around a perimeter point, expressed
# in the coordinate units of the perimeter layer (presumably metres — confirm
# against the layer's CRS).
CELL_SIZE = 100

# Input/output folders. Each one is expected to contain one sub-folder per
# country code; the per-FUA file name patterns are the ones used in the main loop.
dir_perimeter = '../data/processed/perimeter'  # input: '{fua}.gpkg' point layers; also defines the list of FUAs
dir_pop = '../data/processed/demographics'  # input: '{fua}.parquet'
dir_frag = '../data/processed/urban_fragments'  # input: '{fua}_fragments.csv'
dir_pval = '../data/processed/pval'  # input: 'pval_{fua}.parquet'
dir_regions = '../data/processed/regions'  # input: '{fua}_reg.parquet'
dir_synthetic = '../data/processed/synthetic_partition'  # input: '{fua}_synthetic.parquet'
dir_frontiers = '../data/processed/frontiers'  # input: '{fua}_line.gpkg'
dir_summary = '../data/processed/summary'  # output: '{fua}_summary.gpkg'

# Country codes to process (they look like ISO 3166-1 alpha-3 — confirm); each
# code is used as a sub-folder name under every directory above.
list_countries = ['DEU','FRA','ITA','ESP','GBR','PRT','IRL','NLD']

def create_squares_from_centers(points, cell_size=None):
    """
    Creates axis-aligned square polygons centered on point geometries.

    Parameters:
    points (GeoSeries): Point geometries (objects exposing ``.x`` and ``.y``),
        typically the ``geometry`` column of a GeoDataFrame. A whole
        GeoDataFrame is not supported: its ``apply`` would pass columns,
        not individual points, to the square builder.
    cell_size (float, optional): Side length of each square, in the same
        units as the point coordinates. Defaults to the module-level
        CELL_SIZE, looked up at call time.
    Returns:
    GeoSeries: One square polygon per input point, in the same order and
        with the same index as ``points``.
    """
    if cell_size is None:
        cell_size = CELL_SIZE
    # Half the side length is the offset from the center to each edge;
    # compute it once instead of four times per point.
    half = cell_size / 2
    return points.apply(lambda p: box(p.x - half, p.y - half, p.x + half, p.y + half))


# Build one summary GeoPackage per country and per `fua` (presumably a
# functional urban area): a 'grid' layer holding the square cells with all the
# per-cell attributes merged in, and a 'frontiers' layer copied from the
# frontier file.
for country in list_countries:
    # FUA identifiers are the stems of the .gpkg files found in the country's
    # perimeter folder; sorting makes the processing order deterministic.
    list_fuas = sorted(
        os.path.splitext(file_name)[0]
        for file_name in os.listdir(os.path.join(dir_perimeter, country))
        if file_name.endswith('.gpkg')
    )
    destination_folder = os.path.join(dir_summary, country)
    os.makedirs(destination_folder, exist_ok=True)
    for fua in tqdm(list_fuas):
        summary_path = os.path.join(destination_folder, f'{fua}_summary.gpkg')
        # Resume support: a summary that already exists is considered complete.
        if os.path.exists(summary_path):
            continue

        # Perimeter points become square cells of side CELL_SIZE.
        perimeter = gpd.read_file(os.path.join(dir_perimeter, country, f'{fua}.gpkg'))
        perimeter['geometry'] = create_squares_from_centers(perimeter['geometry'])

        # Per-cell attribute tables.
        demog = pd.read_parquet(os.path.join(dir_pop, country, f'{fua}.parquet'))
        pval = pd.read_parquet(os.path.join(dir_pval, country, f'pval_{fua}.parquet'))
        pval = pval.drop(columns=['var_coef','z']).rename(columns={'NOTEU':'mov_avg'})
        regions = pd.read_parquet(os.path.join(dir_regions, country, f'{fua}_reg.parquet'))
        frags = pd.read_csv(os.path.join(dir_frag, country, f'{fua}_fragments.csv'))
        # Only the first two columns of the synthetic partitions are kept.
        synth = pd.read_parquet(os.path.join(dir_synthetic, country, f'{fua}_synthetic.parquet')).iloc[:, :2]
        frontiers = gpd.read_file(os.path.join(dir_frontiers, country, f'{fua}_line.gpkg'))

        # Every merge below is pandas' default inner join, so a cell missing
        # from any one of the inputs is dropped from the output grid.
        # The perimeter index is matched against demog's 'grid_ref' column.
        full_data = perimeter.merge(demog, right_on='grid_ref',left_index=True)
        full_data = full_data.merge(pval, on='grid_ref')
        full_data = full_data.merge(regions[['grid_ref','seg']], on='grid_ref')
        full_data = full_data.merge(frags, on='grid_ref')
        # 'grid_ref' is matched against synth's index.
        full_data = full_data.merge(synth, left_on='grid_ref', right_index=True)

        # Write both layers to a temporary file and move it into place only
        # once it is complete. Otherwise a crash between the two writes would
        # leave a partial summary that the existence check above skips forever.
        tmp_path = os.path.join(destination_folder, f'{fua}_summary.tmp.gpkg')
        try:
            # Discard leftovers of a previously interrupted run.
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        full_data.to_file(tmp_path, layer='grid',driver='GPKG')
        frontiers.to_file(tmp_path, layer='frontiers',driver='GPKG')
        os.replace(tmp_path, summary_path)
        