# -*- coding: utf-8 -*-
#                                                     #
#  __author__ = Adarsh Kalikadien                     #
#  __institution__ = TU Delft                         #
#  __contact__ = a.v.kalikadien@tudelft.nl            #

import os
import time
import tqdm
import pandas as pd
from ase.io import read
from fairchem.core import pretrained_mlip, FAIRChemCalculator

def calculate_uma_energies_for_conformers(model_structure, *, device="cuda", charge=0, spin=1):
    """
    Calculate UMA single-point energies for Ni conformers and save an updated descriptor CSV.

    Reads ``ligand_ni_<model_structure lowercased>_complexes_dft_descriptors_per_conformer_with_xtb_energy.csv``
    from the current working directory. Each row is mapped to the geometry file
    ``<model_structure>/ce<Ligand#>_structure_<Conformer#>_DFT.xyz``, which is evaluated
    with the ``uma-sm`` model (``omol`` task). The table is written, with two extra
    columns (``xyz_file_path`` and ``energy_UMA`` in kJ/mol), to
    ``ligand_ni_<model_structure lowercased>_complexes_dft_descriptors_per_conformer_with_uma_dft_sp.csv``.

    Rows whose XYZ file is missing, or whose evaluation raises, get an empty
    ``energy_UMA`` value. A message is printed for each such row and the run continues.
    If the input CSV has no rows, the model is not loaded at all.

    :param model_structure: 'Cl2' or 'substrate'
    :param device: device passed to ``pretrained_mlip.get_predict_unit`` (default ``"cuda"``)
    :param charge: total charge written to ``atoms.info["charge"]`` (default 0)
    :param spin: value written to ``atoms.info["spin"]`` (default 1); presumably the spin
        multiplicity expected by the ``omol`` task -- confirm against the fairchem docs
    :raises ValueError: if ``model_structure`` is not 'Cl2' or 'substrate'
    :raises KeyError: if the input CSV lacks a ``Ligand#`` or ``Conformer#`` column
    """
    if model_structure not in ['Cl2', 'substrate']:
        raise ValueError("model_structure must be either 'Cl2' or 'substrate'")

    ev_to_kjmol = 96.485  # 1 eV = 96.485 kJ/mol

    input_csv = f'ligand_ni_{model_structure.lower()}_complexes_dft_descriptors_per_conformer_with_xtb_energy.csv'
    df = pd.read_csv(input_csv)

    # Iterate the columns directly so each value keeps its own column dtype.
    # Row-wise DataFrame.apply(axis=1) can upcast integer IDs to float
    # (e.g. 'ce12.0_...') when every column of the table is numeric.
    df['xyz_file_path'] = [
        os.path.join(model_structure, f"ce{ligand_num}_structure_{int(conformer_num)}_DFT.xyz")
        for ligand_num, conformer_num in zip(df['Ligand#'], df['Conformer#'])
    ]

    def compute_uma_energy(xyz_path, calculator):
        """Return the UMA energy in kJ/mol for one XYZ file, or None if it cannot be computed."""
        if not os.path.exists(xyz_path):
            print(f"Missing XYZ file, skipping: {xyz_path}")
            return None
        try:
            mol = read(xyz_path)
            mol.info.update({"spin": spin, "charge": charge})
            mol.calc = calculator
            return mol.get_potential_energy() * ev_to_kjmol
        except Exception as e:
            # Best-effort batch run: report the failure and continue with the remaining conformers.
            print(f"Error computing UMA energy for {xyz_path}: {e}")
            return None

    if df.empty:
        # Nothing to evaluate, so skip loading the model onto the device.
        df['energy_UMA'] = []
    else:
        predictor = pretrained_mlip.get_predict_unit("uma-sm", device=device)
        # A single calculator is shared across conformers. This relies on ASE's standard
        # Calculator state check to re-evaluate each new geometry.
        calculator = FAIRChemCalculator(predictor, task_name="omol")
        progress = tqdm.tqdm(df['xyz_file_path'], desc=f"Computing UMA energies for {model_structure}")
        df['energy_UMA'] = [compute_uma_energy(path, calculator) for path in progress]

    output_csv = f'ligand_ni_{model_structure.lower()}_complexes_dft_descriptors_per_conformer_with_uma_dft_sp.csv'
    df.to_csv(output_csv, index=False)
    print(f"Saved updated dataframe to: {output_csv}")

if __name__ == "__main__":
    import sys

    start_time = time.perf_counter()
    # Optional first CLI argument selects the model structure: 'substrate' (default) or 'Cl2'.
    model_structure = sys.argv[1] if len(sys.argv) > 1 else 'substrate'
    calculate_uma_energies_for_conformers(model_structure)
    print(f"Completed in {time.perf_counter() - start_time:.2f} seconds.")
