# This script builds the covariance matrix of the demographic variable.
import pandas as pd
import os
from scipy import sparse
from tqdm import tqdm
import re

# Distance thresholds. Neither constant is referenced by the visible code.
# NOTE(review): units and intended use are not shown here -- confirm with the
# script that builds the proximity tables before relying on them.
MIN_DISTANCE = 100
MAX_DISTANCE = 2500

# dtype applied to the variance coefficients and to the sparse weight matrix.
FLOAT_PRECISION = "float32"

# Flag that is not read anywhere in the visible code.
test = False

# Input and output directories (relative paths, resolved against the
# current working directory).
dir_demog = "../data/processed/demographics"
dir_proximity = "../data/processed/proximity"
dir_perimeter = "../data/processed/perimeter"
dir_cov_mat = "../data/processed/covariance_matrix"

def compute_cov_mat(distances, col_inh, col_var, dtype=None):
    """
    Build a sparse matrix from population-weighted proximities and per-cell variances.

    Every (from, to) pair receives the weight ``inhabitants[from] * prox``,
    normalised so that the weights sharing the same 'to_grid_ref' sum to one
    (groups whose total is zero get weight 0). With W the weight matrix
    (rows: 'from_grid_ref', columns: 'to_grid_ref') and V the diagonal matrix
    of variances, the function returns ``W.T @ (W @ V)``.

    Parameters:
    distances (pd.DataFrame): DataFrame with columns 'from_grid_ref', 'to_grid_ref' and 'prox'.
        The two grid_ref columns are used directly as row/column positions, so they
        must hold integers in ``[0, len(col_var))``.
    col_inh (pd.Series): Named Series of inhabitant counts indexed by grid_ref. Pairs whose
        'from_grid_ref' is missing from its index are dropped (inner join).
    col_var (pd.Series): Variance values. They are read positionally (entry k goes to
        diagonal position k), so the Series must be ordered by grid_ref. NaN counts as 0.
    dtype: dtype of the weight matrix; defaults to the module-level FLOAT_PRECISION.
    Returns:
    SciPy sparse matrix of shape (len(col_var), len(col_var)).
    """
    if dtype is None:
        dtype = FLOAT_PRECISION

    # Raw weight of each pair: inhabitants of the source cell times proximity.
    pairs = distances.join(col_inh, on='from_grid_ref', how='inner')
    pairs['wgt'] = pairs[col_inh.name] * pairs['prox']

    # Normalise per target cell; a zero total would give 0/0, so force those to 0.
    denom = pairs.groupby('to_grid_ref')['wgt'].transform('sum')
    weights = (pairs['wgt'] / denom).mask(denom == 0, 0)

    # Size the matrix from the variance vector so that W and V are always
    # conformable, even when some cells never appear as 'from_grid_ref'.
    n_cells = len(col_var)
    row_idx = pairs['from_grid_ref'].to_numpy()
    col_idx = pairs['to_grid_ref'].to_numpy()
    W = sparse.csr_matrix(
        (weights.to_numpy(), (row_idx, col_idx)),
        shape=(n_cells, n_cells),
        dtype=dtype,
    )

    # Diagonal variance matrix, in the positional order of col_var.
    V = sparse.diags(col_var.fillna(0).to_numpy())

    # NOTE(review): W.T @ (W @ V) equals (W.T @ W) with column k scaled by
    # var[k]; it is symmetric only when all variances are equal. Confirm this
    # is the intended estimator rather than W.T @ V @ W.
    return W.T @ (W @ V)

def theoretical_variance_resmix(grid, var_coef, city_statistics, fua):
    """
    Scale the per-cell variance coefficients of one Functional Urban Area (FUA)
    by the variance recorded for that FUA in the city statistics.

    Parameters:
    grid (pd.DataFrame): Grid cells of the FUA; must contain a 'GRID_ID' column.
    var_coef (pd.DataFrame): Variance coefficients with 'GRID_ID' and 'var_coef' columns.
    city_statistics (pd.DataFrame): City statistics with 'city' and 'var' columns.
    fua (str): Identifier of the FUA, matched against city_statistics['city'].
    Returns:
    pd.DataFrame: The columns of `grid` (and of `var_coef`) without 'GRID_ID', one row
                    per merged grid cell, where 'var_coef' has been multiplied by the
                    FUA's variance. Missing values in any column are replaced by 0, so
                    cells without a coefficient get 'var_coef' == 0.
    Raises:
    IndexError: If no row of city_statistics has city == fua.
    """
    # Left merge keeps every grid cell, including those without a coefficient.
    merged = grid.merge(var_coef, on='GRID_ID', how='left').fillna(0)

    matches = city_statistics.loc[city_statistics['city'] == fua, 'var']
    if matches.empty:
        # Same exception type as indexing an empty array, but with a message
        # that names the missing city.
        raise IndexError(f"no row with city == {fua!r} in city_statistics")
    # If several rows match, the first one is used.
    var_fua = matches.iloc[0]

    var_coef_fua = merged.drop(columns='GRID_ID')
    var_coef_fua['var_coef'] = var_coef_fua['var_coef'] * var_fua
    return var_coef_fua

def main():
    """
    Build and save one sparse covariance matrix per FUA, for every listed country.

    For each country the variance coefficients are loaded once, and the FUA
    identifiers are taken from the 'indices_<fua>.csv' file names in the
    country's perimeter directory. FUAs whose '<fua>_cov_mat.npz' output
    already exists are skipped, so an interrupted run can be resumed.
    """
    city_statistics = pd.read_csv(os.path.join(dir_demog, 'city_statistics.csv'))
    list_countries = ['DEU', 'FRA', 'ITA', 'ESP', 'GBR', 'PRT', 'IRL', 'NLD']
    # The capture group is the FUA identifier; compiled once for all countries.
    fua_pattern = re.compile(r"indices_(.+?)\.csv")

    for country in list_countries:
        # Load the variance coefficients of the whole country.
        var_coef = pd.read_parquet(os.path.join(dir_demog, country, f'var_coef_{country}.parquet'))
        dir_output = os.path.join(dir_cov_mat, country)

        # Keep only the directory entries that look like perimeter index files.
        file_names = os.listdir(os.path.join(dir_perimeter, country))
        matches = (fua_pattern.search(name) for name in file_names)
        list_fuas = [m.group(1) for m in matches if m]

        for fua in tqdm(list_fuas):
            path_output = os.path.join(dir_output, f'{fua}_cov_mat.npz')
            if os.path.exists(path_output):
                continue

            grid = pd.read_csv(
                os.path.join(dir_perimeter, country, f'indices_{fua}.csv'),
                header=0,
                names=['grid_ref', 'GRID_ID'],
            )
            # if len(grid) > 4e4:
            #     print('Warning: too many contiguities for', country, fua)
            #     continue
            demog = pd.read_parquet(os.path.join(dir_demog, country, f'{fua}.parquet'))
            var_coef_fua = theoretical_variance_resmix(grid, var_coef, city_statistics, fua)
            # Align the key dtype with demog so the merge on 'grid_ref' matches.
            var_coef_fua['grid_ref'] = var_coef_fua['grid_ref'].astype(demog['grid_ref'].dtype)
            var_coef_fua['var_coef'] = var_coef_fua['var_coef'].astype(FLOAT_PRECISION)
            proximity = pd.read_parquet(os.path.join(dir_proximity, country, f'{fua}.parquet'))

            demog = demog.merge(var_coef_fua, on='grid_ref')

            # compute_cov_mat uses grid_ref values as matrix positions but reads
            # the variance column positionally, so order the rows by grid_ref.
            demog = demog.set_index('grid_ref').sort_index()
            cov_mat = compute_cov_mat(proximity, demog['pop'], demog['var_coef'])

            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(dir_output, exist_ok=True)
            sparse.save_npz(path_output, cov_mat)
            # Release the matrix before the next FUA is processed.
            del cov_mat
            

# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()