# -*- coding: utf-8 -*-
#                                                     #
#  __author__ = Adarsh Kalikadien                     #
#  __institution__ = TU Delft                         #
#  __contact__ = a.v.kalikadien@tudelft.nl            #
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import numpy as np
import seaborn as sns
from scipy.stats import spearmanr

# Global matplotlib defaults applied to every figure created by this script.
# Individual plotting calls below may still override some of them
# (e.g. figsize, dpi and bbox_inches are passed explicitly to plt.figure / plt.savefig).
plt.rcParams.update({
    # Figure size and DPI
    'figure.figsize': (10, 6),  # Default figure size (in inches)
    'figure.dpi': 300,        # Higher resolution for better quality

    # Font settings
    'font.size': 12,          # General font size
    'font.family': 'sans-serif',  # Use a clean, sans-serif font
    'font.sans-serif': ['Arial'], # Preferred sans-serif font (replace with another one if desired)

    # Axes settings
    'axes.labelsize': 14,     # Font size for axis labels
    'axes.titlesize': 16,     # Font size for plot titles
    'axes.linewidth': 1.5,    # Thickness of axis lines
    'axes.grid': False,        # Grid lines are OFF by default (plots enable them explicitly via plt.grid())
    'grid.alpha': 0.5,        # Transparency for grid lines when a grid is drawn

    # Tick settings
    'xtick.labelsize': 12,    # Font size for x-axis tick labels
    'ytick.labelsize': 12,    # Font size for y-axis tick labels
    'xtick.major.size': 5,    # Length of major x ticks
    'xtick.major.width': 1.5, # Thickness of major x ticks
    'ytick.major.size': 5,    # Length of major y ticks
    'ytick.major.width': 1.5, # Thickness of major y ticks

    # Legend settings
    'legend.fontsize': 12,    # Font size for legend text
    'legend.frameon': False,  # Remove legend frame

    # Line and marker settings
    'lines.linewidth': 2,     # Line thickness
    'lines.markersize': 8,    # Marker size

    # Save figure settings
    'savefig.dpi': 300,       # DPI for saving figures
    'savefig.transparent': True,  # Transparent background for saved figures
    # 'savefig.bbox': 'tight',  # Tight layout for saved figures (disabled here; savefig calls pass bbox_inches='tight' themselves)
})

# scatter plot of two DataFrame columns with a fitted linear regression line
def scatter_plot_with_regression(df, x_col, y_col, title=None, xlabel=None, ylabel=None):
    """
    Create a scatter plot of two columns with a least-squares regression line.

    The figure is saved as a PNG in the current working directory and then
    shown. The legend title reports R² (squared Pearson correlation between
    the two columns) and RMSE, where RMSE is the root-mean-square of the
    direct difference ``x - y`` (not of the regression residuals).

    Rows where either value is missing or non-finite are ignored for the
    scatter, the fit and the statistics.

    :param df: DataFrame containing the data
    :param x_col: Column name for x-axis
    :param y_col: Column name for y-axis
    :param title: Used only to build the output file name (spaces become
                  underscores, colons are removed); it is not drawn on the
                  figure. If omitted, the name is built from the column names.
    :param xlabel: Label for x-axis (left unset when omitted)
    :param ylabel: Label for y-axis (left unset when omitted)
    :raises ValueError: if fewer than two rows have finite values in both columns
    """
    x_values = df[x_col].to_numpy(dtype=float)
    y_values = df[y_col].to_numpy(dtype=float)

    # np.polyfit / np.corrcoef cannot cope with NaN, so keep only complete pairs
    finite = np.isfinite(x_values) & np.isfinite(y_values)
    x_values, y_values = x_values[finite], y_values[finite]
    if x_values.size < 2:
        raise ValueError(
            f"Need at least two finite ({x_col}, {y_col}) pairs to fit a line, "
            f"got {x_values.size}"
        )

    plt.figure(figsize=(10, 6))
    plt.scatter(x_values, y_values, alpha=0.5, color='blue')

    # linear regression model; a straight line only needs its two end points
    m, b = np.polyfit(x_values, y_values, 1)
    x_line = np.array([x_values.min(), x_values.max()])
    plt.plot(x_line, m * x_line + b, color='red')

    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)

    # add R² and RMSE to the legend (shown as the legend title, there are no legend entries)
    r_squared = np.corrcoef(x_values, y_values)[0, 1] ** 2
    rmse = np.sqrt(np.mean((x_values - y_values) ** 2))
    plt.legend(title=f'R² = {r_squared:.2f}\nRMSE = {rmse:.2f} kJ/mol')
    plt.grid()

    # title is optional: fall back to a name derived from the plotted columns
    file_stem = title if title else f'{x_col}_vs_{y_col}'
    plt.savefig(f'{file_stem.replace(" ", "_").replace(":", "")}.png', bbox_inches='tight', dpi=300)
    plt.show()


def analyze_spearman_per_ligand(df, ligand_col='Ligand#', uma_col='energy_UMA', dft_col='E', min_confs=4, rho_threshold=0.6):
    """
    Compute the Spearman rank correlation between two energy columns per ligand.

    :param df: DataFrame with one row per conformer
    :param ligand_col: Column used to group conformers by ligand
    :param uma_col: Column holding the first set of energies
    :param dft_col: Column holding the second set of energies
    :param min_confs: Ligands with fewer rows than this are skipped entirely
    :param rho_threshold: Minimum Spearman rho for a ligand to be flagged as trusted
    :return: DataFrame with columns ``'Ligand#'``, ``'n_conformers'``,
             ``'spearman_rho'``, ``'spearman_pval'`` and the boolean ``'trusted'``
             (``n_conformers >= min_confs`` and ``p < 0.05`` and
             ``rho >= rho_threshold``). The ligand column of the result is
             always named ``'Ligand#'``, regardless of ``ligand_col``.
             If no ligand qualifies, an empty DataFrame with these columns
             is returned.
    """
    result_columns = ['Ligand#', 'n_conformers', 'spearman_rho', 'spearman_pval']
    spearman_results = []

    for ligand, group in df.groupby(ligand_col):
        if len(group) < min_confs:
            continue  # Skip ligands with fewer than min_confs conformers

        rho, pval = spearmanr(group[uma_col], group[dft_col])
        spearman_results.append({
            'Ligand#': ligand,
            'n_conformers': len(group),
            'spearman_rho': rho,
            'spearman_pval': pval
        })

    # explicit columns so the frame keeps its schema even when nothing qualified
    spearman_df = pd.DataFrame(spearman_results, columns=result_columns)

    if spearman_df.empty:
        spearman_df['trusted'] = pd.Series(dtype=bool)
        return spearman_df

    # add label; a NaN rho/p-value (e.g. constant input) compares False -> not trusted
    spearman_df['trusted'] = (
        (spearman_df['n_conformers'] >= min_confs) &
        (spearman_df['spearman_pval'] < 0.05) &
        (spearman_df['spearman_rho'] >= rho_threshold)
    )

    return spearman_df

def adjust_energies(df):
    """
    Express ``energy_UMA`` relative to each ligand's reference conformer.

    The reference conformer of a ligand is the row whose ``E_rel`` equals 0.
    If several rows of a ligand have ``E_rel == 0``, the first one (in row
    order) is used, so the number of rows is never changed by the merge.

    :param df: DataFrame with columns ``'Ligand#'``, ``'E_rel'`` and ``'energy_UMA'``
    :return: New DataFrame (fresh index, input is not modified) with two extra
             columns: ``'ref_energy_UMA'`` and
             ``'adjusted_energy_UMA' = energy_UMA - ref_energy_UMA``. Both are
             NaN for ligands that have no row with ``E_rel == 0``.
    """
    # reference rows: one per ligand, taken from the rows with E_rel == 0
    ref_df = (
        df.loc[df['E_rel'] == 0, ['Ligand#', 'energy_UMA']]
        .drop_duplicates(subset='Ligand#', keep='first')  # ties would otherwise duplicate rows in the merge
        .rename(columns={'energy_UMA': 'ref_energy_UMA'})
    )

    # merge reference energies back into the main dataframe (one reference per ligand)
    df = df.merge(ref_df, on='Ligand#', how='left', validate='many_to_one')

    # compute adjusted energy (relative to the reference conformer of the same ligand)
    df['adjusted_energy_UMA'] = df['energy_UMA'] - df['ref_energy_UMA']

    return df

def plot_spearman_ranking(spearman_df, model_structure, title, rho_threshold=0.6):
    """
    Bar plot of the per-ligand Spearman rho, coloured by the ``trusted`` flag.

    The figure is saved as ``spearman_ranking_per_ligand_<model_structure>.png``
    in the current working directory and then shown.

    :param spearman_df: DataFrame with columns ``'Ligand#'`` (labels of the form
                        ``L<number>``), ``'spearman_rho'`` and boolean ``'trusted'``.
                        It is not modified.
    :param model_structure: Suffix used in the output file name
    :param title: Currently unused (no title is drawn); kept for call compatibility
    :param rho_threshold: Height of the dashed threshold line and the value shown
                          in its legend label
    """
    plt.figure(figsize=(12, 6))

    # sort on the numeric part of the ligand label (L2 before L11) using a
    # temporary copy, so the caller's DataFrame is left untouched
    ligand_numbers = spearman_df['Ligand#'].str.extract(r'L(\d+)', expand=False).astype(int)
    spearman_df_sorted = spearman_df.assign(Ligand_num=ligand_numbers).sort_values('Ligand_num')

    # dodge=False: hue only colours the bars, every ligand has exactly one bar,
    # so the bars should stay centred on their tick
    sns.barplot(data=spearman_df_sorted,
                x='Ligand#', y='spearman_rho',
                hue='trusted', palette={True: 'dodgerblue', False: 'lightgrey'},
                dodge=False)

    threshold_label = f'ρ = {rho_threshold} threshold'
    plt.axhline(rho_threshold, color='red', linestyle='--', label=threshold_label)
    plt.xticks(rotation=90)
    plt.ylabel("Spearman's ρ (UMA vs DFT)")

    # replace seaborn's automatic True/False legend with descriptive entries
    custom_legend = [
        Patch(color='dodgerblue', label='Reliable'),
        Patch(color='lightgrey', label='Not Reliable'),
        Line2D([0], [0], color='red', linestyle='--', label=threshold_label)
    ]
    plt.legend(handles=custom_legend)
    plt.tight_layout()
    plt.savefig(f'spearman_ranking_per_ligand_{model_structure}.png', bbox_inches='tight', dpi=300)
    plt.show()

if __name__ == '__main__':
    # factor applied to the DFT 'E' column to convert hartree to kJ/mol
    hartree_to_kj_per_mol = 2625.5

    # only keep Ligand# L2, 4, 5, 6, 7, 8, 11, 12, 25, 27, 28, 29, 30, 31, 32, 34, 35, 36, 39 to keep the common ligands
    list_of_common_ligands_in_datasets = ['L2', 'L4', 'L5', 'L6', 'L7', 'L8', 'L11',
                                          'L12', 'L25', 'L27', 'L28', 'L29', 'L30',
                                          'L31', 'L32', 'L34', 'L35', 'L36', 'L39']

    # ---- dataset 1: precatalyst model structures ----
    df = pd.read_csv('ligand_ni_cl2_complexes_dft_descriptors_per_conformer_with_uma_dft_sp.csv')
    # print the total amount of conformers in the dataset (counted BEFORE restricting to the common ligands)
    print(f"Total conformers in dataset: {len(df)}")
    # .copy() makes the filtered frame independent, so the column assignment
    # below does not trigger pandas' SettingWithCopyWarning
    df = df[df['Ligand#'].isin(list_of_common_ligands_in_datasets)].copy()
    # convert E column from hartree to kJ/mol
    df['E'] = df['E'] * hartree_to_kj_per_mol
    df = adjust_energies(df)
    # plot relative energies: x = adjusted_energy_UMA, y = E_rel
    scatter_plot_with_regression(df, 'adjusted_energy_UMA', 'E_rel',
                                 title='Precatalyst model structures',
                                 xlabel=r'$\Delta E_\mathrm{UMA}$ (kJ/mol)',
                                 ylabel=r'$\Delta E_\mathrm{DFT}$ (kJ/mol)')
    # run Spearman ranking analysis and plot
    spearman_df = analyze_spearman_per_ligand(df)
    plot_spearman_ranking(spearman_df, 'precatalyst_model_structures', 'Precatalyst model structures')

    # print summary stats
    print(f"\nTrusted ligands: {spearman_df['trusted'].sum()} / {len(spearman_df)}")

    # ---- dataset 2: activated catalyst model structures ----
    df = pd.read_csv('ligand_ni_substrate_complexes_dft_descriptors_per_conformer_with_uma_dft_sp.csv')
    df = df[df['Ligand#'].isin(list_of_common_ligands_in_datasets)].copy()
    # print the total amount of conformers in the dataset (here counted AFTER restricting to the common ligands)
    print(f"Total conformers in dataset: {len(df)}")
    # convert E column from hartree to kJ/mol
    df['E'] = df['E'] * hartree_to_kj_per_mol
    df = adjust_energies(df)
    # only pick E values below 15 million kJ/mol
    # NOTE(review): if 'E' holds (negative) DFT total energies this upper bound keeps
    # every row; confirm whether a lower bound or abs() was intended. The filter also
    # runs after adjust_energies, so it cannot change the reference energies.
    df = df[df['E'] < 15000000]  # filter out unrealistic high energies
    # plot relative energies: x = adjusted_energy_UMA, y = E_rel
    scatter_plot_with_regression(df, 'adjusted_energy_UMA', 'E_rel',
                                 title='Activated catalyst model structures',
                                 xlabel=r'$\Delta E_\mathrm{UMA}$ (kJ/mol)',
                                 ylabel=r'$\Delta E_\mathrm{DFT}$ (kJ/mol)')
    # spearman ranking analysis and plot
    spearman_df = analyze_spearman_per_ligand(df)
    plot_spearman_ranking(spearman_df, 'activated_catalyst_model_structures', 'Activated catalyst model structures')

    print(f"\nTrusted ligands: {spearman_df['trusted'].sum()} / {len(spearman_df)}")