# -*- coding: utf-8 -*-
#                                                     #
#  __author__ = Adarsh Kalikadien                     #
#  __institution__ = TU Delft                         #
#  __contact__ = a.v.kalikadien@tudelft.nl            #

# Read final_data_Ru_mn_Ir_UMA_SP.csv. For every 'ligand_key' (per metal centre), the 'axial_ligands'
# configuration with 'energy_diff' == 0 is the reference. Subtract that reference's 'energy_UMA_SP' from
# every row of the same group, so that the reference also has a UMA energy of 0 kJ/mol.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
import numpy as np

# Global matplotlib style, applied to every figure created after this point.
# Individual plotting calls may still override these (e.g. an explicit figsize).
plt.rcParams.update({
    # Figure size and DPI
    'figure.figsize': (10, 6),  # Default figure size (in inches)
    'figure.dpi': 300,        # Higher resolution for better quality

    # Font settings
    'font.size': 12,          # General font size
    'font.family': 'sans-serif',  # Use a clean, sans-serif font
    'font.sans-serif': ['Arial'], # Specific font family (replace with preferred one)

    # Axes settings
    'axes.labelsize': 14,     # Font size for axis labels
    'axes.titlesize': 16,     # Font size for plot titles
    'axes.linewidth': 1.5,    # Thickness of axis lines
    'axes.grid': False,        # Grid lines off by default; plots that want them call plt.grid(True)
    'grid.alpha': 0.5,        # Transparency for grid lines (when a plot enables them)

    # Tick settings
    'xtick.labelsize': 12,    # Font size for x-axis tick labels
    'ytick.labelsize': 12,    # Font size for y-axis tick labels
    'xtick.major.size': 5,    # Length of major x ticks
    'xtick.major.width': 1.5, # Thickness of major x ticks
    'ytick.major.size': 5,    # Length of major y ticks
    'ytick.major.width': 1.5, # Thickness of major y ticks

    # Legend settings
    'legend.fontsize': 12,    # Font size for legend text
    'legend.frameon': False,  # Remove legend frame

    # Line and marker settings
    'lines.linewidth': 2,     # Line thickness
    'lines.markersize': 8,    # Marker size

    # Save figure settings
    'savefig.dpi': 300,       # DPI for saving figures
    'savefig.transparent': True,  # Transparent background for saved figures
    # Disabled: the savefig calls in this file pass bbox_inches='tight' explicitly instead.
    # 'savefig.bbox': 'tight',  # Tight layout for saved figures
})

def plot_energy_comparison(df):
    """Show an ungrouped scatter plot of DFT versus UMA relative energies.

    Reads the 'adjusted_energy_UMA_SP' (x axis) and 'energy_diff' (y axis)
    columns of ``df`` and displays the figure. Nothing is returned or saved.
    """
    plt.figure(figsize=(10, 6))

    uma_relative = df['adjusted_energy_UMA_SP']
    dft_relative = df['energy_diff']
    plt.scatter(uma_relative, dft_relative, color='blue', alpha=0.5)

    # axis labels first, then the title; the order does not affect the figure
    plt.xlabel('UMA Energy Difference (kJ/mol)')
    plt.ylabel('DFT Energy Difference (kJ/mol)')
    plt.title('Energy difference UMA vs DFT')

    plt.grid(True)
    plt.show()

def plot_energy_comparison_by_metal(df):
    """Scatter DFT versus UMA relative energies, coloured by metal centre.

    Rows with a missing value in 'adjusted_energy_UMA_SP', 'energy_diff' or
    'element_Rh' are ignored. A single least-squares line through all
    remaining points is drawn in red, and the legend title reports the
    squared Pearson correlation and the RMSE between the two energy columns.
    The figure is written to 'UMA_vs_DFT_configurations.png' and then shown.
    """
    x_col = 'adjusted_energy_UMA_SP'
    y_col = 'energy_diff'
    hue_col = 'element_Rh'

    # keep only rows for which every plotted quantity is available
    complete = df.dropna(subset=[x_col, y_col, hue_col])
    uma = complete[x_col]
    dft = complete[y_col]

    plt.figure(figsize=(10, 6))
    sns.scatterplot(data=complete, x=x_col, y=y_col, hue=hue_col,
                    palette='Set1', alpha=0.7)

    # one global first-degree fit (deliberately not one per metal)
    slope, intercept = np.polyfit(uma, dft, 1)
    fit_x = np.linspace(uma.min(), uma.max(), 100)
    plt.plot(fit_x, slope * fit_x + intercept, color='red', linestyle='-')

    # R² is the squared Pearson coefficient; RMSE is taken directly between
    # the UMA and DFT values, not between the DFT values and the fitted line
    pearson_r = np.corrcoef(uma, dft)[0, 1]
    r_squared = pearson_r ** 2
    deviation = uma - dft
    rmse = np.sqrt(np.mean(deviation ** 2))

    plt.xlabel('UMA Energy Difference (kJ/mol)')
    plt.ylabel('DFT Energy Difference (kJ/mol)')
    # the statistics are carried in the legend title, above the metal entries
    plt.legend(title=f'R² = {r_squared:.2f}\nRMSE = {rmse:.2f} kJ/mol')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('UMA_vs_DFT_configurations.png', bbox_inches='tight', dpi=300)
    plt.show()

def plot_energy_comparison_by_configuration(df):
    """Scatter DFT versus UMA relative energies, coloured by configuration.

    The configuration is the value of the 'axial_ligands' column. Rows with a
    missing value in that column or in either energy column are left out.
    The figure is only shown, not saved.
    """
    needed = ['adjusted_energy_UMA_SP', 'energy_diff', 'axial_ligands']
    complete = df.dropna(subset=needed)
    x_col, y_col, hue_col = needed

    plt.figure(figsize=(10, 6))
    scatter_style = {
        'palette': 'tab10',  # 'husl' is an alternative with more distinct hues
        'alpha': 0.7,
    }
    sns.scatterplot(data=complete, x=x_col, y=y_col, hue=hue_col, **scatter_style)

    plt.title('Energy Differences UMA vs DFT by configuration')
    plt.xlabel('UMA Energy Difference (kJ/mol)')
    plt.ylabel('DFT Energy Difference (kJ/mol)')
    # anchor the legend just outside the axes, at the upper right
    plt.legend(title='Configuration', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def adjust_energies(df):
    """Express 'energy_UMA_SP' relative to the per-group reference row.

    The reference of each ('ligand_key', 'element_Rh') group is the row whose
    'energy_diff' equals 0. Its 'energy_UMA_SP' is subtracted from every row
    of the same group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'ligand_key', 'element_Rh', 'energy_diff'
        and 'energy_UMA_SP'.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same rows as ``df`` plus two columns:
        'ref_energy_UMA_SP' (the group's reference energy) and
        'adjusted_energy_UMA_SP' (energy relative to that reference).
        Groups without any reference row get NaN in both columns.
    """
    group_keys = ['ligand_key', 'element_Rh']

    # reference rows: energy_diff == 0 within each (ligand_key, metal) group
    ref_df = (
        df.loc[df['energy_diff'] == 0, group_keys + ['energy_UMA_SP']]
        .rename(columns={'energy_UMA_SP': 'ref_energy_UMA_SP'})
        # If several rows of a group tie at energy_diff == 0, keep only the
        # first one: more than one reference per group would make the merge
        # below duplicate every row of that group.
        .drop_duplicates(subset=group_keys, keep='first')
    )

    # attach the group's reference energy to each row (left join keeps all rows)
    df = df.merge(ref_df, on=group_keys, how='left')

    # compute adjusted energy (relative to reference per metal-ligand combo)
    df['adjusted_energy_UMA_SP'] = df['energy_UMA_SP'] - df['ref_energy_UMA_SP']

    return df

def analyze_spearman_per_ligand_metal(df, ligand_col='ligand_key', metal_col='element_Rh',
                                      uma_col='adjusted_energy_UMA_SP', dft_col='energy_diff',
                                      min_configs=3, rho_threshold=0.6,
                                      exclude_metals=('Ir',)):
    """Compute the Spearman rank correlation of UMA vs DFT per (metal, ligand).

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the four columns named by the ``*_col`` arguments.
    ligand_col, metal_col : str
        Columns used to group the rows.
    uma_col, dft_col : str
        Columns whose rankings are compared within each group.
    min_configs : int
        Groups with fewer rows than this are skipped.
    rho_threshold : float
        Minimum rho for a group to be flagged as trusted.
    exclude_metals : iterable of str or None
        Metals whose groups are skipped; ``None`` excludes nothing.

    Returns
    -------
    pandas.DataFrame
        One row per analysed group with columns 'element_Rh', 'Ligand#',
        'n_configs', 'spearman_rho', 'spearman_pval' and 'trusted'
        (p-value < 0.05 and rho >= ``rho_threshold``). The output column
        names are fixed regardless of ``metal_col`` / ``ligand_col``. When no
        group qualifies, an empty frame with these columns is returned.
    """
    columns = ['element_Rh', 'Ligand#', 'n_configs', 'spearman_rho', 'spearman_pval']
    excluded = () if exclude_metals is None else exclude_metals
    results = []

    for (metal, ligand), group in df.groupby([metal_col, ligand_col]):
        if metal in excluded or len(group) < min_configs:
            continue

        rho, pval = spearmanr(group[uma_col], group[dft_col])
        results.append({
            'element_Rh': metal,
            'Ligand#': ligand,
            'n_configs': len(group),
            'spearman_rho': rho,
            'spearman_pval': pval,
        })

    # naming the columns explicitly keeps the schema stable when results is empty
    spearman_df = pd.DataFrame(results, columns=columns)
    if spearman_df.empty:
        spearman_df['trusted'] = pd.Series(dtype=bool)
        return spearman_df

    spearman_df['trusted'] = (
        (spearman_df['spearman_pval'] < 0.05) &
        (spearman_df['spearman_rho'] >= rho_threshold)
    )
    return spearman_df

def plot_spearman_barplot(spearman_df, rho_threshold=0.6):
    """Show one bar per (metal, ligand) combination with its Spearman rho.

    Parameters
    ----------
    spearman_df : pandas.DataFrame
        Needs the columns 'element_Rh', 'Ligand#' (both strings),
        'spearman_rho' and 'trusted'. The frame is not modified.
    rho_threshold : float
        Height of the dashed red reference line (default 0.6).
    """
    # work on a copy so the helper columns below do not leak into the caller's frame
    plot_df = spearman_df.copy()
    plot_df['combo'] = plot_df['element_Rh'] + "_" + plot_df['Ligand#']
    # numeric part of the ligand label (e.g. 'L12' -> 12.0); NaN if there is none
    plot_df['Ligand_num'] = plot_df['Ligand#'].str.extract(r'L(\d+)').astype(float)
    plot_df = plot_df.sort_values(['element_Rh', 'Ligand_num'])

    plt.figure(figsize=(14, 6))
    sns.barplot(data=plot_df, x='combo', y='spearman_rho',
                hue='trusted', palette={True: 'dodgerblue', False: 'lightgrey'})

    plt.axhline(rho_threshold, color='red', linestyle='--', label=f'ρ = {rho_threshold}')
    plt.xticks(rotation=90)
    plt.ylabel("Spearman's ρ (UMA vs DFT)")
    plt.title("Spearman Ranking per (Metal, Ligand)")
    plt.legend(title='Trusted')
    plt.tight_layout()
    plt.show()

def plot_spearman_by_metal(spearman_df, rho_threshold=0.6):
    """Show one bar-plot panel of Spearman rho per metal centre.

    Parameters
    ----------
    spearman_df : pandas.DataFrame
        Needs the columns 'element_Rh', 'Ligand#' (strings), 'spearman_rho'
        and 'trusted'.
    rho_threshold : float
        Height of the dashed red reference line drawn in every panel.
    """
    metals = sorted(spearman_df['element_Rh'].unique())
    if not metals:
        # a figure with zero rows cannot be created
        print("No Spearman results to plot.")
        return

    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when there is only a single metal (otherwise a bare Axes is returned)
    _, axes = plt.subplots(len(metals), 1, figsize=(14, 5 * len(metals)),
                           sharex=False, squeeze=False)

    for ax, metal in zip(axes[:, 0], metals):
        subset = spearman_df[spearman_df['element_Rh'] == metal].copy()
        # order the bars by the numeric part of the ligand label
        subset['Ligand_num'] = subset['Ligand#'].str.extract(r'L(\d+)').astype(float)
        subset = subset.sort_values('Ligand_num')

        sns.barplot(
            data=subset,
            x='Ligand#',
            y='spearman_rho',
            hue='trusted',
            palette={True: 'dodgerblue', False: 'lightgrey'},
            ax=ax
        )

        # the label follows the threshold actually drawn
        ax.axhline(rho_threshold, color='red', linestyle='--', label=f'ρ = {rho_threshold}')
        ax.set_title(f'Spearman Ranking for Metal Center: {metal}')
        ax.set_ylabel("Spearman's ρ")
        ax.tick_params(axis='x', rotation=90)
        ax.legend(title='Trusted', loc='lower right')

    plt.tight_layout()
    plt.show()

def plot_spearman_heatmap(spearman_df):
    """Draw a ligand x metal heatmap of Spearman rho, then save and show it.

    Parameters
    ----------
    spearman_df : pandas.DataFrame
        Needs the columns 'Ligand#' (strings), 'element_Rh' and
        'spearman_rho', with at most one row per (ligand, metal) pair.

    The figure is written to
    'configurations_spearman_heatmap_per_ligand_metal.png'.
    """
    # rows: ligands, columns: metals, cells: rho
    pivot = spearman_df.pivot(index='Ligand#', columns='element_Rh', values='spearman_rho')

    # Sort ligand rows by the numeric part of their label. Float is used rather
    # than int so a label without 'L<digits>' becomes NaN (sorted last) instead
    # of raising. Reordering by position keeps helper data out of the columns.
    ligand_num = pivot.index.to_series().str.extract(r'L(\d+)')[0].astype(float)
    pivot = pivot.iloc[np.argsort(ligand_num.to_numpy(), kind='stable')]

    # heatmap with red-to-green scale: Red (bad), Green (good)
    cmap = sns.color_palette("RdYlGn", as_cmap=True)

    plt.figure(figsize=(12, 12))
    # colour scale is fixed to [0, 1]; values outside that range saturate
    sns.heatmap(pivot, annot=True, cmap=cmap, center=0.5, vmin=0, vmax=1)
    # plt.title("Spearman ρ (UMA vs DFT) per Ligand and Metal Center")
    plt.ylabel("Ligand Number")
    plt.xlabel("Metal Center")
    plt.tight_layout()
    plt.savefig('configurations_spearman_heatmap_per_ligand_metal.png', bbox_inches='tight', dpi=300)
    plt.show()

def print_trusted_summary(spearman_df):
    """Print how many (metal, ligand) combinations are flagged as trusted.

    Output, in order: the overall count and percentage, one line for every
    combination that is not trusted (with its rho and p-value), and the
    count and percentage per metal centre.

    Parameters
    ----------
    spearman_df : pandas.DataFrame
        Needs the columns 'element_Rh', 'Ligand#', 'spearman_rho',
        'spearman_pval' and a boolean 'trusted'.
    """
    total = len(spearman_df)
    if total == 0:
        # nothing to summarise; also avoids a 0 / 0 percentage below
        print("Trusted combinations overall: 0 / 0 (no combinations to summarise)")
        return

    total_trusted = int(spearman_df['trusted'].sum())
    print(f"Trusted combinations overall: {total_trusted} / {total} "
          f"({total_trusted / total * 100:.1f}%)")

    # list Ligand#, element_Rh, rho and p-value of every combination that is not trusted
    not_trusted = spearman_df[~spearman_df['trusted']]
    if not not_trusted.empty:
        print("\nNot trusted combinations:")
        for _, row in not_trusted.iterrows():
            print(f"  {row['Ligand#']} - {row['element_Rh']}: ρ = {row['spearman_rho']:.2f}, "
                  f"p-value = {row['spearman_pval']:.2e}")

    # Group by metal and compute trusted % per metal
    summary = spearman_df.groupby('element_Rh')['trusted'].agg(['sum', 'count'])
    summary['percentage'] = summary['sum'] / summary['count'] * 100

    print("Trusted combinations per metal center:")
    for metal, row in summary.iterrows():
        # iterrows yields float values for this frame, hence the int() casts
        print(f"  {metal}: {int(row['sum'])} / {int(row['count'])} trusted "
              f"({row['percentage']:.1f}%)")

# Load the dataset (relative path: the CSV is looked up in the current working directory)
df = pd.read_csv('final_data_Ru_mn_Ir_UMA_SP.csv')
# report the number of distinct ligands, i.e. unique 'ligand_key' values
print(f"Number of unique Ligand# in the dataframe: {df['ligand_key'].nunique()}")
# Adjust first: the plots and the Spearman analysis below read 'adjusted_energy_UMA_SP',
# which adjust_energies adds.
df = adjust_energies(df) # make sure UMA energies are also relative to the same reference
# Commented-out calls below are optional diagnostic plots that are currently disabled.
# plot_energy_comparison(df)
plot_energy_comparison_by_metal(df)
# plot_energy_comparison_by_configuration(df)
# rank agreement per (metal, ligand); the call uses the function's default arguments
spearman_df = analyze_spearman_per_ligand_metal(df)
# plot_spearman_by_metal(spearman_df)
plot_spearman_heatmap(spearman_df)
print_trusted_summary(spearman_df)