# -*- coding: utf-8 -*-
#                                                     #
#  __author__ = Adarsh Kalikadien                     #
#  __institution__ = TU Delft                         #
#  __contact__ = a.v.kalikadien@tudelft.nl            #

import functools
import glob
import os
import time

import pandas as pd
import tqdm
from ase.io import read
from fairchem.core import pretrained_mlip, FAIRChemCalculator
# For every ligand in the dataframe, take the DFT .log name from the `file` column and strip the
# `_restart` and `_SP` tags from it. Then use the metal center (Ir, Mn or Ru, taken from the filename)
# and the `axial_ligands` column to locate the matching .xyz file,
# e.g. the .xyz file for [Ir+3]_L17_OH0_DFT_SP.log is Ir/H-N_axial/[Ir+3]_L17_OH0_DFT.xyz

def get_xyz_file_path(row):
    """Map the DFT log filename in ``row`` to the relative path of its .xyz geometry.

    The metal symbol is the first underscore-separated token of ``row['file']``
    with every non-alphabetic character removed (e.g. ``[Ir+3]`` -> ``Ir``).
    All occurrences of ``_restart`` and ``_SP`` are removed from the filename and
    ``.log`` is replaced by ``.xyz``. The file is expected under
    ``<metal>/<row['axial_ligands']>_axial/``.

    Example: ``[Ir+3]_L17_OH1_restart_DFT_SP.log`` with axial ligands ``H-N``
    maps to ``Ir/H-N_axial/[Ir+3]_L17_OH1_DFT.xyz``.

    Returns the path as a '/'-separated string (relative to the working directory).
    """
    log_name = row['file']
    leading_token = log_name.split('_', 1)[0]
    metal = ''.join(ch for ch in leading_token if ch.isalpha())

    # Apply the substitutions in a fixed order: drop tags first, then swap the extension.
    xyz_name = log_name
    for old, new in (('_restart', ''), ('_SP', ''), ('.log', '.xyz')):
        xyz_name = xyz_name.replace(old, new)

    return f"{metal}/{row['axial_ligands']}_axial/{xyz_name}"

# iterate over rows and calculate energies, add to same row in the dataframe or Nan if the file does not exist
def calculate_energy(row):
    xyz_file_path = row['xyz_file_path']
    if not os.path.exists(xyz_file_path):
        return None  # or np.nan if you prefer
    try:
        structure = read(xyz_file_path)
        structure.info.update({"spin": 1, "charge": 0})  # Assuming singlet state
        predictor = pretrained_mlip.get_predict_unit("uma-sm", device="cuda")
        structure.calc = FAIRChemCalculator(predictor, task_name="omol")
        energy_kJmol = structure.get_potential_energy() * 96.485 # convert from eV to kJ/mol
        return energy_kJmol
    except Exception as e:
        print(f"Error processing {xyz_file_path}: {e}")
        return None

# Register DataFrame.progress_apply so the energy loop shows a tqdm progress bar.
tqdm.tqdm.pandas(desc="Calculating energies")

# Input datasets (Ru, Mn and Ir by file name); rows are stacked in this order.
INPUT_CSVS = ['final_data_Ru.csv', 'final_data_mn_fixed.csv', 'final_data_Ir.csv']
OUTPUT_CSV = 'final_data_Ru_mn_Ir_UMA_SP.csv'

# ignore_index=True gives the combined frame a fresh 0..n-1 index; otherwise the
# per-file indices repeat and the result has duplicate row labels.
df = pd.concat([pd.read_csv(path) for path in INPUT_CSVS], ignore_index=True)
print('Length of dataset:', len(df))

# Resolve the .xyz geometry path for every row.
df['xyz_file_path'] = df.apply(get_xyz_file_path, axis=1)
# Compute the UMA single-point energy per row (None/NaN where unavailable).
df['energy_UMA_SP'] = df.progress_apply(calculate_energy, axis=1)

df.to_csv(OUTPUT_CSV, index=False)
