# -*- coding: utf-8 -*-
"""
Created on Tue Oct 12 11:10:12 2021

@author: francescvarkev
"""
import pickle
import os
from neuron import h
import numpy as np
from model import initCell
from scipy.optimize import curve_fit
from WaveformModule import waveform
from scipy.integrate import simps
from time import time

def loadPickle(filename):
    """Load and return the object stored in the pickle file ``filename``.

    The file is always closed, even if unpickling raises.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    call this on files produced by this project / trusted sources.
    """
    with open(filename, "rb") as a_file:
        return pickle.load(a_file)


def calcOptParams(data, meta, cell_id):
    """Collect, per location, the parameters of the lowest-``'e'`` stimulus.

    For every location in ``meta['loclist']``, each waveform in
    ``meta['shapelist']`` is scanned. The waveform whose ``'e'`` array reaches
    the smallest value is selected (ties keep the first one seen). The
    duration, threshold and ``'q'`` value at that minimum are recorded,
    together with geometric properties of the NEURON section the excitation
    was located in.

    Parameters
    ----------
    data : dict
        Nested results. ``data[cell_id][loc][w][0][0]`` must provide the keys
        ``'e'``, ``'d'``, ``'s'``, ``'q'``, ``'SD'`` and ``'dist'``.
    meta : dict
        Needs ``'loclist'`` (each ``loc`` indexable as ``loc[0]..loc[2]``)
        and ``'shapelist'``.
    cell_id : int
        Key of the cell in ``data``; also passed to ``initCell``.

    Returns
    -------
    dict
        ``opt[loc]`` holds ``'pw_<w>'`` for every waveform plus:
        ``'w'``, ``'e'``, ``'pw'``, ``'th'``, ``'q'``, ``'ex_loc'``,
        ``'Node'``, ``'ex_dist'``, ``'diam'``, ``'end'`` and ``'min_dist'``.
        ``'L'`` and ``'par_diam'`` are added only when ``Node == 1``.
    """
    from model import calcr
    initCell(cell_id)
    opt = {}
    loclist = meta['loclist']
    wlist = meta['shapelist']
    for loc in loclist:
        # np.inf (rather than a finite sentinel such as 1e6) guarantees that
        # some waveform is always selected, so opt[loc]['w'] cannot be missing
        # when all values are large.
        min_e = np.inf
        opt[loc] = {}

        for w in wlist:
            e = data[cell_id][loc][w][0][0]['e']
            e_min = min(e)
            # First index at which the minimum is reached.
            pw = np.where(np.asarray(e) == e_min)[0][0]
            opt[loc]['pw_' + str(w)] = data[cell_id][loc][w][0][0]['d'][pw]
            if e_min < min_e:  # strict '<': ties keep the earlier waveform
                min_e = e_min
                opt[loc]['w'] = w

        best = data[cell_id][loc][opt[loc]['w']][0][0]
        opt[loc]['e'] = min(best['e'])
        i = np.where(np.asarray(best['e']) == opt[loc]['e'])[0][0]
        opt[loc]['pw'] = best['d'][i]
        opt[loc]['th'] = np.abs(best['s'][i])
        opt[loc]['q'] = best['q'][i]

        # Prefer an excitation location whose name starts with 'N' (the last
        # such entry wins); otherwise fall back to the first listed location.
        ex_locs = best['SD'][opt[loc]['pw']]['ex_loc']
        opt[loc]['ex_loc'] = []
        for ex_loc in ex_locs:
            if ex_loc[0] == 'N':
                opt[loc]['ex_loc'] = ex_loc
                opt[loc]['Node'] = 1
        if opt[loc]['ex_loc'] == []:
            opt[loc]['ex_loc'] = ex_locs[0]
            opt[loc]['Node'] = 0

        # SECURITY NOTE: eval() on a string read from the loaded pickle; only
        # safe because the pickle is trusted project output.
        sec = eval('h.' + opt[loc]['ex_loc'])
        opt[loc]['ex_dist'] = calcr(sec, loc[0], loc[1], loc[2])
        opt[loc]['diam'] = sec.diam

        # 'end' flags a section without children.
        opt[loc]['end'] = 1 if sec.children() == [] else 0
        if opt[loc]['Node'] == 1:
            opt[loc]['L'] = sec.parentseg().sec.L
            opt[loc]['par_diam'] = sec.parentseg().sec.diam
        opt[loc]['min_dist'] = best['dist']

    return opt

def renameCell(data, meta, cell_id=None):
    """Normalise cell-instance names in the stored excitation locations.

    Parallel processing instantiates different instances of the cell, so the
    section names in ``...['SD'][dur]['ex_loc']`` can refer to different
    instance names. Every entry that starts with ``meta['cell']`` minus its
    last three characters has the part before its first ``'.'`` replaced by
    ``meta['cell']``. Other entries are left untouched.

    Parameters
    ----------
    data : dict
        Nested results, modified in place.
    meta : dict
        Needs ``'loclist'``, ``'dlist'``, ``'shapelist'`` and ``'cell'``.
        It also needs ``'cell_id'`` when ``cell_id`` is not given.
    cell_id : int, optional
        Key of the cell in ``data``. Defaults to ``int(meta['cell_id'])``;
        previously this was read from an undefined module global.

    Returns
    -------
    dict
        The same ``data`` object.
    """
    if cell_id is None:
        cell_id = int(meta['cell_id'])
    cell = meta['cell']
    # NOTE(review): assumes the instance suffix of meta['cell'] is exactly
    # three characters long (e.g. "[0]").
    prefix = cell[0:-3]

    for loc in meta['loclist']:
        for w in meta['shapelist']:
            for dur in meta['dlist']:
                sd_entry = data[cell_id][loc][w][0][0]['SD'][dur]
                renamed = []
                for sec in sd_entry['ex_loc']:
                    if sec.startswith(prefix):
                        # Swap only the leading component (before the first '.').
                        _head, sep, rest = sec.partition('.')
                        renamed.append(cell + sep + rest)
                    else:
                        renamed.append(sec)
                sd_entry['ex_loc'] = renamed

    return data

def calcChronaxie(data, meta, cell_id=None):
    """Estimate the chronaxie of every (waveform, location) threshold curve.

    The thresholds ``'s'`` are linearly interpolated (``np.interp``) over the
    durations ``'d'`` on a grid from ``dlist[0]`` to ``dlist[-1]`` with step
    ``dlist[0] / 100``. The chronaxie is the last grid duration at which
    ``|s|`` still exceeds twice the minimum ``|s|``.

    Stored in ``data[cell_id][loc][w][0][0]``:
    ``'sinterp'`` (interpolated thresholds), ``'chronaxie'`` (``nan`` if no
    grid point exceeds the level) and ``'s_chronaxie'`` (twice the minimum
    ``|s|``).

    Parameters
    ----------
    data : dict
        Nested results, modified in place.
    meta : dict
        Needs ``'dlist'``, ``'shapelist'`` and ``'loclist'``. It also needs
        ``'cell_id'`` when ``cell_id`` is not given.
    cell_id : int, optional
        Key of the cell in ``data``. Defaults to ``int(meta['cell_id'])``;
        previously this was read from an undefined module global.

    Returns
    -------
    dict
        The same ``data`` object.
    """
    if cell_id is None:
        cell_id = int(meta['cell_id'])
    durlist = meta['dlist']
    wlist = meta['shapelist']
    loclist = meta['loclist']

    # NOTE: np.interp assumes entry['d'] is increasing.
    xinterp = np.arange(durlist[0], durlist[-1], durlist[0] / 100)
    for w in wlist:
        for loc in loclist:
            entry = data[cell_id][loc][w][0][0]
            level = 2 * min(np.abs(entry['s']))
            entry['sinterp'] = np.interp(xinterp, entry['d'], entry['s'])
            above = np.where(np.abs(entry['sinterp']) > level)[0]
            # Previously an IndexError when nothing exceeded the level.
            entry['chronaxie'] = xinterp[above[-1]] if above.size else np.nan
            entry['s_chronaxie'] = level
    return data

# def It_function(x, I0, tau): # define function to be fitted
#     return I0*1/(1-np.exp(-x/tau))

def It_function(x, I0, tau):
    """Strength-duration model evaluated at pulse duration ``x``.

    Returns ``I0 * (1 + tau / x)``. Works element-wise when ``x`` is a
    NumPy array; used as the model passed to ``curve_fit``.
    """
    scale = 1 + tau / x
    return I0 * scale

def curveFitSD(data, meta):
    """Fit ``It_function`` to every (waveform, location) threshold curve.

    The absolute thresholds ``|s|`` are fitted against the durations
    ``meta['dlist']`` with ``scipy.optimize.curve_fit``. The model is
    ``I0 * (1 + tau / d)`` with initial guess ``[|s|[-1], 0.2]``. The fit
    uses the samples from index ``i`` onward:

    * ``i`` is the last index where ``|s|`` exceeds twice ``|s[-1]|``,
      capped at 5;
    * ``i`` is 0 when no sample exceeds that level.

    Results go into ``data[cell_id][loc][w][0][0]['fit']``:

    * ``'I0'`` and ``'tau'`` — the fitted parameters;
    * ``'stdevs'`` — the square roots of the covariance diagonal;
    * ``'res'`` — residuals over all durations;
    * ``'i'`` — the start index;
    * ``'s'``, ``'e'``, ``'q'`` — the fitted amplitude, the integral of the
      squared waveform, and the waveform integral, each resampled on the
      grid ``'d'``;
    * ``'d'`` — ``np.arange(dlist[0], dlist[-1], dlist[0])``.

    Prints each waveform identifier after it has been processed.

    Parameters
    ----------
    data : dict
        Nested results, modified in place.
    meta : dict
        Needs ``'shapelist'``, ``'dlist'``, ``'loclist'``, ``'cell_id'`` and
        ``'dt'``.

    Returns
    -------
    dict
        The same ``data`` object.
    """
    shapes = meta['shapelist']
    durations = meta['dlist']
    locations = meta['loclist']
    cid = int(meta['cell_id'])
    dt = meta['dt']

    # Resampling grid: starts at, and steps by, the shortest duration; the
    # final duration itself is excluded by np.arange.
    fine_durations = np.arange(durations[0], durations[-1], durations[0])

    for w in shapes:
        for loc in locations:
            entry = data[cid][loc][w][0][0]
            Ith = np.abs(entry['s'])

            # Where to start fitting: last index still above twice the final
            # threshold, but never later than index 5.
            above = np.where(Ith > 2 * np.abs(entry['s'][-1]))[0]
            start = min([above[-1], 5]) if above.size else 0

            pars, cov = curve_fit(f=It_function, xdata=durations[start:], ydata=Ith[start:],
                                  p0=[Ith[-1], 0.2], bounds=(-np.inf, np.inf))
            I0 = pars[0]
            tau = pars[1]
            fit = {
                'I0': I0,
                'tau': tau,
                'stdevs': np.sqrt(np.diag(cov)),  # std deviations of I0 and tau
                'res': Ith - It_function(np.array(durations), *pars),
                'i': start,
            }
            entry['fit'] = fit

            amplitudes = []
            charges = []
            energies = []
            for pw in fine_durations:
                amp = It_function(pw, I0, tau)
                amplitudes.append(amp)
                wave, t = waveform(pw, 1, amp, w, dt)
                # NOTE(review): scipy.integrate.simps is deprecated in favour
                # of scipy.integrate.simpson in recent SciPy releases.
                charges.append(simps(wave, t))
                energies.append(simps(np.multiply(wave, wave), t))

            fit['s'] = amplitudes
            fit['e'] = energies
            fit['q'] = charges
            fit['d'] = fine_durations
        print(w)

    return data

if __name__ == '__main__':
    # Script entry point: post-process the raw results of one cell and write
    # Data_raw.pkl, meta.pkl, opt.pkl and Data_simple.pkl to its data folder.
    tic = time()
    ### Load raw data and meta files ###
    cell_id = input('Cell id: ')
    cell_id = int(cell_id)  # module-level name; other helpers may rely on it
    cell_dir = 'Datafiles/' + str(cell_id)
    data = loadPickle(cell_dir + '/Data.pkl')
    meta = loadPickle(cell_dir + '/meta.pkl')

    ### Do some conversion operations ###
    # meta['dlist'] = data[cell_id][meta['loclist'][0]][meta['shapelist'][0]][0][0]['d']
    data = renameCell(data, meta)
    opt = calcOptParams(data, meta, cell_id)
    data = calcChronaxie(data, meta)
    t_int1 = time()
    data = curveFitSD(data, meta)
    t_int2 = time() - t_int1
    print(t_int2)

    ### Save raw Data and meta files ###
    # 'with' ensures every file is closed even if pickling fails.
    for fname, obj in (('/Data_raw.pkl', data), ('/meta.pkl', meta), ('/opt.pkl', opt)):
        with open(cell_dir + fname, "wb") as a_file:
            pickle.dump(obj, a_file)

    ### Remove 'SD' entry from data ###
    for key in data[cell_id]:
        for key2 in data[cell_id][key]:
            data[cell_id][key][key2][0][0].pop('SD', None)

    ### Save simple version of Data ###
    with open(cell_dir + '/Data_simple.pkl', "wb") as a_file:
        pickle.dump(data, a_file)
    toc = time() - tic
    print(toc)
        

