# -*- coding: utf-8 -*-
"""
Created on Wed Nov  6 15:53:51 2024

@author: bagch002
"""

import statistics
from scipy.signal import savgol_filter
from scipy.signal import find_peaks
from scipy.optimize import curve_fit
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import csv
import tkinter as tk
from tkinter import filedialog
import math
import pandas as pd
from glob import glob
import os
import os.path
import easygui as egui
import re
from scipy.ndimage import gaussian_filter
from matplotlib.gridspec import GridSpec
from functools import reduce  # Used for multiply the value in a list, look at sizeImg
from tkinter import filedialog
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import FuncFormatter

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import FuncFormatter
import numpy as np  # Needed for logarithm calculation

def colorbar(mappable, *, scale_threshold=4):
    """Attach a colorbar to the right of ``mappable``'s axes with scaled labels.

    A new axes (5% of the width, 0.05 pad) is appended to the right of the
    image axes and the colorbar is drawn in it. When the decimal exponent of
    the colorbar maximum is at least ``scale_threshold``, tick labels are
    divided by ``10**exponent`` and the factor is shown as ``1e<exponent>``
    above the bar. Otherwise labels are unscaled. Labels use one decimal place.

    Parameters
    ----------
    mappable : matplotlib ScalarMappable
        Artist drawn in an axes, e.g. the return value of ``Axes.imshow``.
    scale_threshold : int, optional
        Minimum exponent of the colorbar maximum for which labels are
        rescaled (default 4, i.e. maxima >= 1e4).

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The new colorbar. The axes that was current before the call is
        made current again.
    """
    previous_axes = plt.gca()  # creating the colorbar axes can change the current axes
    ax = mappable.axes
    fig = ax.figure
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(mappable, cax=cax)

    # Decimal exponent of the colorbar maximum; clamping to 1 keeps log10 away
    # from zero/negative values (exponent is then 0).
    exponent = int(np.floor(np.log10(max(cbar.vmax, 1))))
    rescale = exponent >= scale_threshold
    scale_factor = 10 ** exponent if rescale else 1

    def scaled_formatter(value, _pos):
        return f"{value / scale_factor:.1f}"

    cbar.ax.yaxis.set_major_formatter(FuncFormatter(scaled_formatter))

    if rescale:
        # Show the common factor that was divided out of the tick labels.
        cbar.ax.set_title(f"1e{exponent}", fontsize=12, pad=10)

    plt.sca(previous_axes)
    return cbar
 
def scalingImage(self, dict_header, img_pixels):
    """Read the X and Y axis calibration arrays stored inside an image file.

    The header values ``ScalingXScalingFile`` and ``ScalingYScalingFile`` start
    with ``#<byte offset>,``. ``img_pixels`` float32 values are read from each
    offset of the same file.

    Parameters
    ----------
    self : str or path-like
        Path of the image file (the parameter name is kept for backward
        compatibility with existing callers).
    dict_header : dict
        Parsed header, as returned by ``header2dic``.
    img_pixels : int
        Number of calibration values to read for BOTH axes.
        NOTE(review): the same count is used for X and Y; for non-square
        images the X calibration probably needs a different length — confirm.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x_scale, y_scale)``, each a float64 array of length ``img_pixels``.

    Raises
    ------
    ValueError
        If ``img_pixels`` is negative or the file ends before ``img_pixels``
        values can be read at either offset.
    """
    # The byte offset is the number between '#' and the first ',' of the value.
    address_scalingXfile = int(re.search("#([0-9]*),", dict_header["ScalingXScalingFile"]).group(1))
    address_scalingYfile = int(re.search("#([0-9]*),", dict_header["ScalingYScalingFile"]).group(1))

    if img_pixels < 0:
        raise ValueError(f"img_pixels must be non-negative, got {img_pixels}")
    n_bytes = 4 * img_pixels  # float32 -> 4 bytes per value

    def _read_scale(f, address, label):
        # Read all values in one call instead of 4 bytes at a time.
        f.seek(address)
        raw = f.read(n_bytes)
        if len(raw) < n_bytes:
            raise ValueError(
                f"{label}: expected {n_bytes} bytes at offset {address}, got {len(raw)}"
            )
        # float64 output, matching the np.zeros buffers used previously.
        return np.frombuffer(raw, np.float32).astype(np.float64)

    # A single file handle serves both reads.
    with open(self, "rb") as f:
        x_scale = _read_scale(f, address_scalingXfile, "ScalingXScalingFile")
        y_scale = _read_scale(f, address_scalingYfile, "ScalingYScalingFile")
    return x_scale, y_scale
 
 
def read_header(self):
    """Read the text header of an image file and return it as a string.

    Up to 8000 bytes are read from the current position of ``self``. The
    header is the span from the ``[Application]`` tag up to and including the
    first ``UserComment=""`` that follows it.

    Parameters
    ----------
    self : binary file object
        File opened for reading (the parameter name is kept for backward
        compatibility). Offsets are taken relative to its current position.

    Returns
    -------
    str
        The header decoded as UTF-8. On return the file is positioned just
        after the header.

    Raises
    ------
    ValueError
        If either marker is not found within the scanned buffer.
    """
    start_tag = b'[Application]'
    end_tag = b'UserComment=""'

    base = self.tell()  # buffer indices below are relative to this position
    header_buffer = self.read(8000)  # scan window; the header must fit inside it

    start_header = header_buffer.find(start_tag)
    if start_header < 0:
        raise ValueError(f"Header start {start_tag!r} not found in the first 8000 bytes")
    end_index = header_buffer.find(end_tag, start_header)
    if end_index < 0:
        raise ValueError(f"Header end {end_tag!r} not found in the first 8000 bytes")
    end_header = end_index + len(end_tag)

    # Leave the file just after the header, as callers may continue reading from there.
    self.seek(base + end_header)
    return header_buffer[start_header:end_header].decode("utf-8")
 
def header2dic(self):  # Create dictionary to store the value of the header
    """Parse a header string (as returned by ``read_header``) into a flat dict.

    Processing steps:

    1. ``\\r\\n`` line breaks and double quotes are removed.
    2. The text is cut into sections, each running from a ``[`` to the next
       ``[`` or the end of the text (so the last section is included too).
    3. Each section is split into items at commas that follow a digit or word
       character (or ``]``) and precede a letter. Commas inside numeric lists
       such as ``0,0,672,512`` are therefore kept.

    The first item of each section is discarded because it is expected to be
    just the ``[Section]`` tag. NOTE(review): this relies on the header having
    a comma right after the tag; otherwise the section's first key is lost.

    Items are split at their first ``=``. Values made only of digits become
    ``int``; all other values stay ``str``. All sections share one namespace,
    so a repeated key keeps its last value.

    Parameters
    ----------
    self : str
        Header text (the parameter name is kept for backward compatibility).

    Returns
    -------
    dict
        Mapping of header keys to ``int`` or ``str`` values.

    Raises
    ------
    ValueError
        If an item contains no ``=``.
    """
    text = self.replace("\r\n", "").replace('"', "")
    item_split = re.compile(r'(?<=\d|"|\]|\w),(?=[a-zA-Z])')

    HeaderDict = {}
    for section in re.findall(r'\[[^\[]*', text):
        for item in item_split.split(section)[1:]:  # [0] is the section tag
            key, sep, value = item.partition("=")  # keep any further '=' in the value
            if not sep:
                raise ValueError(f"Header item without '=': {item!r}")
            HeaderDict[key] = int(value) if value.isdigit() else value
    return HeaderDict
 
def read_data(self, header, dictionary):
    """Read the pixel block stored right after the text header.

    The pixel data is read from byte ``offset + <UTF-8 length of header>``
    with ``offset = 64``. NOTE(review): this offset is hard-coded, not read
    from the file; presumably it equals the position where ``read_header``
    found ``[Application]`` — confirm for other files/instruments.

    Parameters
    ----------
    self : binary file object
        File opened for reading (the parameter name is kept for backward
        compatibility).
    header : str
        Header text returned by ``read_header``; only its encoded length is used.
    dictionary : dict
        Parsed header from ``header2dic``. ``BytesPerPixel`` selects the pixel
        type (4 -> int32, 2 -> int16). The values of ``areSource`` after its
        first two comma-separated entries give the image size (width, height).

    Returns
    -------
    numpy.ndarray
        Image of shape ``(height, width)``. It is a view on the bytes read,
        so it is read-only.

    Raises
    ------
    ValueError
        If ``BytesPerPixel`` is not 2 or 4, or the file is shorter than the
        image requires.
    """
    offset = 64  # assumed fixed start of the header (see NOTE above)

    # NOTE(review): 2-byte pixels are read as signed int16; confirm the data is
    # not unsigned 16-bit (counts above 32767 would wrap negative).
    dtypes = {4: np.int32, 2: np.int16}
    bytes_per_pixel = dictionary["BytesPerPixel"]
    if bytes_per_pixel not in dtypes:
        raise ValueError(f"Error! BytePerPixel Unknown: {bytes_per_pixel!r}")
    dtype = dtypes[bytes_per_pixel]

    # Skip the first two areSource entries; the rest are the image dimensions.
    sizeImg = [int(v) for v in dictionary["areSource"].split(",")[2:]]
    totalsizeImg = int(np.prod(sizeImg)) * bytes_per_pixel  # bytes to read

    # File positions are in bytes; use the encoded header length so non-ASCII
    # characters in the header do not shift the data start.
    self.seek(offset + len(header.encode("utf-8")))
    data_img = self.read(totalsizeImg)
    if len(data_img) < totalsizeImg:
        raise ValueError(
            f"Image data truncated: expected {totalsizeImg} bytes, got {len(data_img)}"
        )

    return np.frombuffer(data_img, dtype).reshape(sizeImg[1], sizeImg[0])
 
def spectral_average(spect_psi, lifetime_psi, spect_psii, lifetime_psii):
    """Return the lifetime-weighted sum of two spectra, normalised to its peak.

    Every value of ``spect_psi`` is multiplied by ``lifetime_psi`` and every
    value of ``spect_psii`` by ``lifetime_psii``. The two weighted spectra are
    added element-wise and divided by the largest value of the sum.
    """
    weighted_psi = np.array([value * lifetime_psi for value in spect_psi])
    weighted_psii = np.array([value * lifetime_psii for value in spect_psii])
    combined = weighted_psi + weighted_psii
    peak = max(combined)
    return combined / peak

def temporal_average(spect_psi, lifetime_psi, spect_psii, lifetime_psii, time, normfact):
    """Model a two-component exponential decay and scale its maximum to ``normfact``.

    Each component is weighted by the sum of its spectrum:

        I(t) = sum(spect_psi) * exp(-t / lifetime_psi)
             + sum(spect_psii) * exp(-t / lifetime_psii)

    evaluated for every ``t`` in ``time``. The result is divided by its
    maximum and multiplied by ``normfact``.

    Returns
    -------
    numpy.ndarray
        One value per entry of ``time``.
    """
    weight_psi = sum(spect_psi)
    weight_psii = sum(spect_psii)
    decay = [
        weight_psi * math.exp(-t / lifetime_psi) + weight_psii * math.exp(-t / lifetime_psii)
        for t in time
    ]
    return np.array(decay) / max(decay) * normfact

def add_blur(spect_psi, lifetime_psi, spect_psii, lifetime_psii, time, normfact, extend, sigma_val):
    """Gaussian-blur the modelled decay after zero-padding it on the left.

    The decay from ``temporal_average`` gets ``extend`` zeros prepended and is
    smoothed with ``scipy.ndimage.gaussian_filter(sigma=sigma_val)``. The
    result is divided by its maximum and multiplied by ``normfact``.

    Returns
    -------
    numpy.ndarray
        Array of length ``extend + len(time)`` (for ``extend >= 0``).
    """
    decay = temporal_average(spect_psi, lifetime_psi, spect_psii, lifetime_psii, time, normfact)
    # Zeros before the decay leave room for the blur to spread to earlier times.
    padded = np.concatenate(([0] * extend, decay))
    smoothed = gaussian_filter(padded, sigma=sigma_val)
    return smoothed / max(smoothed) * normfact
 
 
# Analysis parameters. None of them is read by the active code in this
# script (peak_num only appears in commented-out lines). NOTE(review): they
# look like leftovers from related analyses — confirm before removing.
to_take_heim = 7
to_take_lifetime = 12
to_take = to_take_lifetime  # currently the same value as to_take_lifetime
peak_num = 3
peak_val1 = 130
 
 
#path = 'C:\\Users\\bos159\\OneDrive - Wageningen University & Research\\Streak data\\240725_ara_DCMU_per_wv\\430\\40uW_430nm_60gain_500ms_150exp_bgsh_240722_l_h.img'
#folder_path = filedialog.askdirectory()


#path= 'H:\\Data\\2024\\20240327_streak_photoinhibition_4h_linco\\Streak_data\\corrected_images\\20240327_05_Fm_plant3_dark_linco_25uW_450i_250ms_625nm_5nm_tr5.img'


#path= 'C:\\Users\\bagch002\\OneDrive - Wageningen University & Research\\Claudia - STREAK measurements/20241210_Claudia_mutants_tuesday\\15_ZEP_KO_Fm_DCMU_1sec_600ex_40uW_BG_SH.img'
#path= 'C:\\Users\\bagch002\\OneDrive - Wageningen University & Research\\Claudia - STREAK measurements/20241211_Claudia_mutants_wednesday\\03_WT_Fm_DCMU_700msec_300ex_40uW_no_greenlaser_BG_SH.img'
#path= 'C:\\Users\\bagch002\\OneDrive - Wageningen University & Research\\Claudia - STREAK measurements/20241210_Claudia_mutants_tuesday\\19_ZEP_L1L2_KO_189_stock_Fm_DCMU_1sec_600ex_40uW_BG_SH.img'
path = 'C:\\Users\\bagch002\\OneDrive - Wageningen University & Research\\Claudia - STREAK measurements/20250217_Claudia_mutants_monday\\20250217_05_ZEP_L1L2_KO_DCMU_1sec_600ex_40uW_BG_SH.img'

# Work from the image's folder; the figures below are saved under base_path.
current_dir = str(os.path.dirname('/'.join(path.split('\\'))))
os.chdir(current_dir)
base_path = os.getcwd()
data_files = [path.split('\\')[-1]]  # file name(s) relative to base_path
tt = data_files[0]  # file whose data is plotted below
skip_rows = 10  # leading image rows (and y-scale entries) discarded from each file
extracted_spectra = {}  # file name -> [image, x scale, y scale]
for hh in data_files:
    filename = os.path.join(base_path, hh)
    with open(filename, "rb") as f:
        header = read_header(f)
        dict_header = header2dic(header)
        img = read_data(f, header, dict_header)
        # NOTE(review): the row count is passed as the length of BOTH the X and
        # Y calibration arrays; check this for non-square images.
        xscale, yscale = scalingImage(filename, dict_header, np.shape(img)[0])
    # Drop the first rows; .copy() keeps the array writable (read_data returns a read-only view).
    img = img[skip_rows:, :].copy()
    yscale = yscale[skip_rows:]
    # Stored inside the loop so every file in data_files is kept, not only the last.
    extracted_spectra[hh] = [img, xscale, yscale]
wv_axis = extracted_spectra[tt][1]
dummy_t = extracted_spectra[tt][2] - extracted_spectra[tt][2][0]  # time axis starting at 0
test = extracted_spectra[tt][0]

#%% if you want to cutoff Fm to 3000

# index=np.argmax(dummy_t>3000)
# dummy_t=dummy_t[:index]
# test=test[:index,:]

#%% 
def _nearest_index(axis, target):
    """Return the index of the first element of ``axis`` closest to ``target``."""
    distances = [abs(value - target) for value in axis]
    return distances.index(min(distances))


# Decimal exponent of the maximum of the row-averaged image (not used further below).
peak_trace = max(np.average(test, axis=1))
a_Fm = int(f"{peak_trace:E}".split('E')[1])
Fm_factor = 10 ** a_Fm

# Candidate tick values, kept only if they fall inside the data range.
wv_lo, wv_hi = min(wv_axis), max(wv_axis)
t_lo, t_hi = min(dummy_t), max(dummy_t)
x_axis = [tick for tick in np.arange(650, 820, 20) if wv_lo <= tick <= wv_hi]
y_axis = [tick for tick in np.arange(0, 9001, 750) if t_lo <= tick <= t_hi]

# Pixel positions of the ticks. Wavelengths are matched on the reversed axis
# because the image columns are flipped before plotting.
reversed_wv = wv_axis[::-1]
selected_x_pos = [_nearest_index(reversed_wv, tick) for tick in x_axis]
selected_y_pos = [_nearest_index(dummy_t, tick) for tick in y_axis]

wv_range = list(range(640, 820, 20))
time_range = list(range(0, int(dummy_t[-1]), 750))
numberwv = len(wv_range)
numbertime = len(time_range)

fig5 = plt.figure()
gs = GridSpec(1, 1, figure=fig5)
ax6 = fig5.add_subplot(gs[0, 0])
pos1 = ax6.imshow(np.fliplr(test), cmap='turbo', vmin=test.min(), vmax=test.max())

ax6.set_xticks(selected_x_pos, x_axis)
ax6.set_yticks(selected_y_pos, y_axis)

ax6.set_title('ZEP PsbS L1L2 KO - DCMU', fontsize=16, loc='center')
ax6.set_xlabel('Wavelength (nm)', fontsize=14)
ax6.set_ylabel('Time (ps)', fontsize=14)
ax6.tick_params(labelsize=12)
c1 = colorbar(pos1)
c1.set_label("Fluorescence Intensity (a.u.)", fontsize=14)
fig5.savefig(f"{base_path}/ZEP_PsbS_L1L2_KO.png", dpi=1200)
fig5.savefig(f"{base_path}/ZEP_PsbS_L1L2_KO.svg")


plt.show()

#%% figure in a square
# fig5 = plt.figure()
# gs = GridSpec(1, 1, figure=fig5)
# ax6 = fig5.add_subplot(gs[0, 0])
# pos1=ax6.imshow([x[::-1] for x in test], aspect=len(wv_axis)/len(dummy_t), cmap='turbo', vmin=np.amin(test), vmax=np.amax(test))

# ax6.set_xticks(selected_x_pos, x_axis)
# ax6.set_yticks(selected_y_pos, y_axis)
 
# ax6.set_title('Open RC - HL', fontsize=16, loc='center')
# ax6.set_xlabel('Wavelength (nm)', fontsize=14)
# ax6.set_ylabel('Time (ps)', fontsize=14)
# ax6.tick_params(labelsize=12)
# c1 = colorbar(pos1)
# c1.set_label("Fluorescence Intensity (a.u.)", fontsize=14)
# fig5.savefig(base_path+'/'+'open_RC_HL_new.png',dpi=1200)
# fig5.savefig(base_path+'/'+'open_RC_HL_new.svg')
 
 
# plt.show()