import numpy as np
from scipy.interpolate import CubicSpline
import scipy as sc
from itertools import product
from itertools import cycle
#from numba import njit
import datetime
import scipy.interpolate as sc
from collections import defaultdict
from matplotlib.ticker import ScalarFormatter
import pandas as pd
import math
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Module-level constants
# ---------------------------------------------------------------------------
R = 2.4e-9  # a length in metres (NOTE(review): quantity not identified here)
T = 298  # temperature [K]
kb = 1.380649 * 10 ** (-23)  # Boltzmann constant [J K^-1]
Nav = 6.02214076 * 10 ** 23  # Avogadro's number [mol^-1]

# Timestamp captured once, when the module is loaded: YYYYMMDD_HHMMSS.
ct = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
ct_str = str(ct)  # string form of the same timestamp

def get_column_contents(file_path, column_name):
    """Return one column of a tab-separated file, in reversed row order.

    Parameters
    ----------
    file_path : str or path-like
        Tab-separated text file whose first line holds the column names.
    column_name : str
        Exact header of the column to extract.

    Returns
    -------
    numpy.ndarray
        Values of the requested column with the row order flipped
        (the last row of the file comes first).

    Raises
    ------
    ValueError
        If ``column_name` ` is not among the file's column headers.
    """
    # Let pandas detect the line terminator itself. Forcing '\n' leaves a
    # stray '\r' attached to the last header and to the last field of every
    # row when the file uses Windows (CRLF) line endings, so that column
    # could never be found by name.
    df = pd.read_csv(file_path, sep='\t')

    if column_name not in df.columns:
        raise ValueError(f'Column {column_name} not found in the file.')

    return np.flip(df[column_name].values)

def format_scientific(num):
    """Render the order of magnitude of ``num`` as a LaTeX power of ten.

    Only the decimal exponent is kept: the mantissa and the sign of ``num``
    do not appear in the label, so 2500 and -7300 both give ``$10^{3}$``.

    Parameters
    ----------
    num : int or float
        Number to label.

    Returns
    -------
    str
        ``"0"`` when ``num`` equals zero, otherwise ``$10^{p}$`` where ``p``
        is ``floor(log10(abs(num)))``.
    """
    if num == 0:
        return "0"
    # The mantissa num / 10**power is deliberately not computed: it is not
    # part of the label, and for the smallest subnormal floats 10**power
    # underflows to 0.0, which made the division raise ZeroDivisionError.
    power = math.floor(math.log10(abs(num)))
    return r'$10^{' + str(power) + '}$'

    
def _reference_ph(Csalt):
    """Return the reference (lowest) pH used for the salt concentration ``Csalt``."""
    reference_ph_by_csalt = {
        0.1: 2.0,
        0.01: 2.0,
        0.001: 3.0,
        0.0001: 4.0,
        0.00001: 5.0,
    }
    try:
        return reference_ph_by_csalt[Csalt]
    except KeyError:
        raise ValueError(
            f'No reference pH defined for Csalt={Csalt}; supported values: '
            f'{sorted(reference_ph_by_csalt)}'
        ) from None


def _flat_kal_path(base_path, Csalt, pH, sigma_p, ChiPS, pKa):
    """Build the path of the flat-surface .kal file for one (Csalt, pH) point."""
    import os  # local import keeps this block independent of the file header

    # NOTE(review): 'NP=100' is hard-coded in the file name and does not
    # follow the Np argument of get_data_flat -- confirm this is intended.
    pieces = [
        'flat_surf_pol_cr_NP=100_Csalt=', str(Csalt),
        '_M_pH=', str(pH),
        '_Sigma=', str(sigma_p),
        '_mol_nm2_ChiPS=', str(ChiPS),
    ]
    # Files computed with pKa 4.5 carry no pKa tag in their name.
    if pKa != 4.5:
        pieces += ['_pKa=', str(pKa)]
    pieces.append('_blanco.kal')
    # os.path.join yields the same backslash-separated path as before on
    # Windows and a valid path on other platforms.
    return os.path.join(base_path, ''.join(pieces))


def _restricted_energy(file_path, m, sigma_p, sigma_f):
    """Return ``(layers, energy)`` read from one .kal file.

    ``energy`` is the 'grand potential' column, minus the product of the
    'FH-MU - H3O' column with the 'PH' (and, if ``sigma_f != 0``, 'funcH')
    theta column, plus the 'sum n FH-MU' column(s) of the restricted
    molecules. Columns are only read when their term is switched on.
    """
    layers = get_column_contents(file_path, 'lat : solution : n_layers')
    grand_potential = get_column_contents(file_path, 'sys : noname : grand potential')

    if m != 0 and sigma_p != 0:
        polymer_proton_term = (
            get_column_contents(file_path, 'mol : solvent : FH-MU - H3O')
            * get_column_contents(file_path, 'state : PH : theta')
        )
    else:
        polymer_proton_term = 0

    if sigma_f != 0:
        surface_proton_term = (
            get_column_contents(file_path, 'mol : solvent : FH-MU - H3O')
            * get_column_contents(file_path, 'state : funcH : theta')
        )
        surface_term = get_column_contents(file_path, 'mol : funcsurface : sum n FH-MU')
    else:
        surface_proton_term = 0
        surface_term = 0

    if sigma_p != 0:
        polymer_term = get_column_contents(file_path, 'mol : poly : sum n FH-MU')
    else:
        polymer_term = 0

    corrected_gp = grand_potential - polymer_proton_term - surface_proton_term
    energy = corrected_gp + polymer_term + surface_term

    if len(energy) != len(layers):
        raise ValueError(f'Energy and layer columns of {file_path} differ in length.')
    return layers, energy


def _pressure_curve(layers, energy, lmax_flat, b, delta_interp):
    """Return ``(delta_mid, pressure, pressure_on_grid)`` for one energy curve.

    The pressure is minus the finite difference of ``energy`` between
    consecutive entries, times ``kb * T / b`` and divided by ``b * b``. It is
    located at the midpoints of the normalised layer numbers and then
    linearly interpolated/extrapolated onto ``delta_interp``.
    """
    delta = (layers + 1.) / lmax_flat
    diff_energy = np.diff(energy) * kb * T / b
    diff_vol = b * b
    pressure = -(diff_energy / diff_vol)
    delta_mid = (delta[1:] + delta[:-1]) * 0.5
    # `sc` is scipy.interpolate here: the later module-level import rebinds it.
    interpolant = sc.interp1d(delta_mid, pressure, fill_value='extrapolate')
    return delta_mid, pressure, interpolant(delta_interp)


def _plot_pressure_cycle(delta_interp, deltapress0, pres0, press0_lin,
                         deltapress1, pres1, press1_lin, idx0, idx1):
    """Draw the two pressure curves and the cycle between them, then save
    the figure as PAPERpress_curves_flat2.png / .pdf (figure 3)."""
    delta0lin = delta_interp[idx0]
    delta1lin = delta_interp[idx1]
    mean = (delta1lin + delta0lin) * 0.5
    # delta0lin / delta1lin are grid points of the strictly increasing
    # delta_interp, so the grid indices closest to them are idx0 / idx1.
    start_delta = idx0
    stop_delta = idx1
    arrow_pos = int((start_delta + stop_delta) / 2)

    plt.figure()
    plt.tick_params(axis='both', which='major', labelsize=15)
    print("arrow pos index", arrow_pos, "arrow pos", delta_interp[arrow_pos])

    # Raw curves (pressures are multiplied by 1e-6 for the MPa axis).
    plt.plot(deltapress0, pres0 * 1e-6, label=r"pH = 3", linestyle='dashed', color="C0", lw=2)
    plt.plot(deltapress1, pres1 * 1e-6, label=r"pH = 9", color="grey", lw=2)
    plt.axhline(0, c='darkgrey')
    plt.ylim(-20, 30)
    plt.text(delta0lin - 0.1, 20, r'$\Delta \Pi_{\mathrm{dep}}$', rotation=0, fontsize=12, color="k")
    plt.text(delta1lin + 0.02, -7.5, r'$\Delta \Pi_{\mathrm{pro}}$', rotation=0, fontsize=12, color="k")

    # Outline of the cycle: both interpolated branches plus the two vertical jumps.
    cycle = slice(start_delta, stop_delta)
    plt.plot(delta_interp[cycle], press0_lin[cycle] * 1e-6, c="k", lw=2.5)
    plt.plot(delta_interp[cycle], press1_lin[cycle] * 1e-6, c="k", lw=2.5)
    plt.vlines(delta0lin, press0_lin[idx0] * 1e-6, press1_lin[idx0] * 1e-6, color="k", lw=2.5)
    plt.vlines(delta1lin, press0_lin[idx1] * 1e-6, press1_lin[idx1] * 1e-6, color="k", lw=2.5)

    # Arrows along the cycle.
    plt.annotate('', xy=(delta_interp[arrow_pos], press0_lin[arrow_pos] * 1e-6),
                 xytext=(delta_interp[arrow_pos + 20], press0_lin[arrow_pos + 20] * 1e-6),
                 arrowprops=dict(facecolor="k", shrink=0.1))
    plt.annotate('', xy=(delta_interp[arrow_pos + 30], press1_lin[arrow_pos + 30] * 1e-6),
                 xytext=(delta_interp[arrow_pos], press1_lin[arrow_pos] * 1e-6),
                 arrowprops=dict(facecolor="k", shrink=0.1))
    plt.annotate('', xy=(delta0lin, 10), xytext=(delta0lin, 8),
                 arrowprops=dict(facecolor="k", shrink=0.1))
    plt.annotate('', xy=(delta1lin, -4), xytext=(delta1lin, -2),
                 arrowprops=dict(facecolor="k", shrink=0.7))

    # Arrow and labels indicating that the shaded area is the work.
    plt.arrow(x=0.475, y=1, dx=0.05, dy=12, head_width=0.01, head_length=0.5, fc='k', ec='k')
    plt.text(0.51, 14, "work in one cycle", rotation=0, fontsize=15, color="k")
    plt.text(0.56, 1, r"$\delta_{\mathrm{A}}$", rotation=0, fontsize=15, color="k")
    plt.text(0.22, -3, r"$\delta_{\mathrm{B}}$", rotation=0, fontsize=15, color="k")
    plt.axhline(0., c="grey")

    half_width = (delta1lin - delta0lin) * 0.5
    plt.fill_between(
        x=deltapress0,
        y1=pres0 * 1e-6,
        where=abs(deltapress0 - mean) < half_width,
        color='#22b573',
        alpha=0.2)
    plt.fill_between(
        x=deltapress1,
        y1=pres1 * 1e-6,
        where=abs(deltapress1 - mean) < half_width,
        color='#22b573',
        alpha=0.2)

    plt.legend(fontsize=15)
    plt.xlabel(r"Normalized strain $\delta$", fontsize=15)
    plt.ylabel(r"Pressure $\Pi$ [MPa]", fontsize=15)
    plt.savefig("PAPERpress_curves_flat2.png", bbox_inches="tight")
    plt.savefig("PAPERpress_curves_flat2.pdf", bbox_inches="tight")


def get_data_flat(Np, Csalt, sigma_f, pH, sigma_p, m, pKa, ChiPS, b):
    """Compare the pressure curve at ``pH`` with the one at the reference pH.

    Two .kal files under ``./kal_files`` are read: one at the reference pH
    associated with ``Csalt`` and one at ``pH``. For each, an energy curve is
    built, differentiated into a pressure curve and interpolated on a
    1000-point grid of the normalised layer number. The grid points where
    each pressure is closest to zero define the quantities that are returned.

    Side effect: for ``Csalt == 0.001`` and ``pH == 9`` a figure is drawn and
    saved as ``PAPERpress_curves_flat2.png`` / ``.pdf``.

    Parameters
    ----------
    Np : int
        Sets the largest layer count, ``lmax_flat = 2 * Np + 1``, used to
        normalise layer numbers (presumably the chain length -- confirm).
    Csalt : float
        Salt concentration; must be one of 0.1, 0.01, 0.001, 0.0001, 0.00001.
    sigma_f : number
        Must be 0 (only this case is supported).
    pH : number
        pH of the second state; inserted into the file name via ``str``.
    sigma_p : number
        Must be non-zero; inserted into the file name and multiplied by 1e18
        (per nm^2 to per m^2, going by the variable and file names).
    m : number
        Only compared with zero: a non-zero value switches on the term
        built from the 'FH-MU - H3O' and 'PH' theta columns.
    pKa, ChiPS : number
        Only used to select the input file names.
    b : float
        Length used to turn energy differences into pressures
        (``main`` passes 0.3e-9; presumably the lattice spacing in metres).

    Returns
    -------
    dict
        Input parameters plus ``int_bad``, ``delta0lin``, ``delta1lin``,
        ``deltap_pro3``, ``deltap_dep3``, ``M`` and four theta values
        (``Poly_total``, ``Pm_total``, ``PH_total``, ``P2_total``) read from
        the second file at the row derived from ``delta1lin``.

    Raises
    ------
    ValueError
        If ``Csalt`` has no reference pH, ``sigma_f`` is not 0, ``sigma_p``
        is 0, or a required column is missing from an input file.
    """
    # The reference state is the minimum pH available for this salt concentration.
    pH0 = _reference_ph(Csalt)
    if sigma_f != 0:
        raise ValueError('get_data_flat only supports sigma_f == 0.')
    if sigma_p == 0:
        raise ValueError('get_data_flat requires sigma_p != 0.')

    base_path = r'./kal_files'
    if pKa == 4.5:
        print("read files ", pH0, pH)
    file0 = _flat_kal_path(base_path, Csalt, pH0, sigma_p, ChiPS, pKa)
    file1 = _flat_kal_path(base_path, Csalt, pH, sigma_p, ChiPS, pKa)

    sigma_p_m2 = sigma_p * 1e18
    M_flat = b * sigma_p_m2 * b  # number of polymer chains
    lmax_flat = 2 * Np + 1

    layers0, en0 = _restricted_energy(file0, m, sigma_p, sigma_f)
    layers1, en1 = _restricted_energy(file1, m, sigma_p, sigma_f)

    # At the reference pH both states are the same, so the energy curves must
    # coincide (no work / disjoining pressure). Compare the arrays themselves:
    # the former `en1.all() == en0.all()` only compared two booleans.
    if pH0 == pH:
        assert np.allclose(en1, en0, equal_nan=True)

    # Pressure curves on a common grid of the normalised layer number.
    delta_interp = np.linspace(0., 1., 1000)
    deltapress0, pres0, press0_lin = _pressure_curve(layers0, en0, lmax_flat, b, delta_interp)
    deltapress1, pres1, press1_lin = _pressure_curve(layers1, en1, lmax_flat, b, delta_interp)

    # Grid indices where each interpolated pressure is closest to zero.
    idx0 = np.argmin(np.abs(press0_lin))
    idx1 = np.argmin(np.abs(press1_lin))
    delta0lin = delta_interp[idx0]
    delta1lin = delta_interp[idx1]

    # Pressure jumps between the two curves at those two points.
    deltap_pro3 = press0_lin[idx1] - press1_lin[idx1]
    deltap_dep3 = press1_lin[idx0] - press0_lin[idx0]

    # Rectangle-rule difference of the two curves between idx0 and idx1
    # (an empty slice, hence 0, when idx0 >= idx1).
    # NOTE(review): the grid spacing of delta_interp is 1/999 while dx is
    # 0.001; kept unchanged so results stay identical.
    dx = 0.001
    y_heights0 = press0_lin[idx0:idx1]
    y_heights1 = press1_lin[idx0:idx1]
    int_bad = (np.sum(y_heights1 * dx) - np.sum(y_heights0 * dx)) * (b**3 * lmax_flat)

    # This is for figure 3.
    if Csalt == 0.001 and pH == 9:
        _plot_pressure_cycle(delta_interp, deltapress0, pres0, press0_lin,
                             deltapress1, pres1, press1_lin, idx0, idx1)

    # Row of the second file matching delta1lin. get_column_contents returns
    # columns in reversed row order, hence the (1 - delta1lin) mapping.
    index_delta1lin = int((1 - delta1lin) * lmax_flat) - 1
    PH_total = get_column_contents(file1, 'state : PH : theta')[index_delta1lin]
    Pm_total = get_column_contents(file1, 'state : Pm : theta')[index_delta1lin]
    Poly_total = get_column_contents(file1, 'mol : poly : theta')[index_delta1lin]
    P2_total = get_column_contents(file1, 'mon : P2 : theta')[index_delta1lin]

    return {
        'geometry': 'flat',
        'Np': Np,
        'Csalt': Csalt,
        'sigma_f': sigma_f,
        'pH': pH,
        'sigma_p': sigma_p,
        'm': m,
        'int_bad': int_bad,
        'delta0lin': delta0lin,
        'delta1lin': delta1lin,
        'deltap_pro3': deltap_pro3,
        'deltap_dep3': deltap_dep3,
        'M': M_flat,
        'Poly_total': Poly_total,
        'Pm_total': Pm_total,
        'PH_total': PH_total,
        'P2_total': P2_total,
        'pKa': pKa,
        'ChiPS': ChiPS,
        # Include other relevant data as needed
    }

def main():
    """Sweep salt concentration and pH for the flat geometry and export the
    per-state results of ``get_data_flat`` to ``results_flat.xlsx``."""
    b = 0.3e-9
    Np = 100
    sigma_f = 0
    sigma_p = 1
    m = 0.5
    pKa = 4.5
    ChiPS = 0.4
    Csalt_range = [0.1, 0.01, 0.001, 0.0001, 0.00001]
    geometry = "flat"
    results = []

    for Csalt in Csalt_range:
        # The sweep starts at max(2, -log10(Csalt)) and advances in steps of
        # one pH unit. For Csalt == 0.01 it stops one unit earlier
        # (NOTE(review): presumably no data exist for pH 12 there -- confirm).
        pH_stop = 12.0 if Csalt == 0.01 else 13.0
        pH_range = np.round(np.arange(max(2, -np.log10(Csalt)), pH_stop, 1.0), 1)
        print(Csalt, pH_range)
        for pH in pH_range:
            if geometry == 'cylindrical':
                print("wrong geometry!")
            elif geometry == 'flat':
                results.append(
                    get_data_flat(Np, Csalt, sigma_f, pH, sigma_p, m, pKa, ChiPS, b)
                )

    # Build the table straight from the result dicts: column names come from
    # the dict keys, so nothing depends on a loop variable that would be
    # undefined if the sweep produced no entry.
    df = pd.DataFrame(results)

    # Save to Excel
    excel_filename = 'results_' + geometry + '.xlsx'
    df.to_excel(excel_filename, index=False)

    print(f'Results saved to {excel_filename}')

# Run the parameter sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
       


    
