# -*- coding: utf-8 -*-
"""
Created on Wed Jan  5 10:03:49 2022

@author: francescvarkev
"""
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
from neuron import h

# SciPy 1.14 removed ``simps``; ``simpson`` (SciPy >= 1.6) is the same routine
# under its new name, so prefer it and fall back for very old SciPy versions.
try:
    from scipy.integrate import simpson as simps
except ImportError:  # SciPy < 1.6 only provides the original name
    from scipy.integrate import simps

from WaveformModule import waveform


h.load_file('stdgui.hoc')
h.load_file('interpxyz.hoc')

def init_model(config, S=1):
    """Build a myelinated axon of alternating nodes and myelin from ``config``.

    All existing sections are deleted first, so every call starts from an
    empty model. ``Myelin[i]`` is connected to ``Node[i]``, and ``Node[i]``
    (for i > 0) is connected to ``Myelin[i - 1]``.

    Parameters
    ----------
    config : dict
        Model parameters, e.g. as returned by ``setConfig``. Keys read here:
        'vinit', 'celsius', 'nodes', 'Ra', 'epas', 'ek', 'ena', 'nseg',
        'L_myelin', 'diam_myelin', 'Cm_myelin', 'gpas_myelin', 'Cm_node',
        'diam_node', 'L_node', 'gpas_node', 'gKPst', 'gKTst', 'gNat',
        'gNap', 'gSKv3_1'.
    S : float, optional
        Scale factor applied to myelin length, myelin diameter and node
        diameter (node length is not scaled). Default 1.

    Returns
    -------
    Node : list of h.Section
    Myelin : list of h.Section
    v : dict
        Maps ``str(sec)`` (e.g. 'Node[0]') to an h.Vector recording the
        membrane potential at ``sec(0.5)`` on every time step.
    tvec : h.Vector
        Records ``h.t``.
    """
    h.v_init = config['vinit']
    h.celsius = config['celsius']
    n = config['nodes']

    # Snapshot the section list before deleting: removing sections while
    # iterating h.allsec() directly is unsafe.
    for sec in list(h.allsec()):
        h.delete_section(sec=sec)

    Node = []
    Myelin = []
    for i in range(n):
        Node.append(h.Section(name=f'Node[{i}]'))
        # Every node, including the last one, is followed by a myelin section
        # (the former guard ``if i != n`` was always true inside range(n)).
        # NOTE(review): if the axon was meant to end on a node, the guard
        # should have been ``i != n - 1`` -- confirm before changing.
        Myelin.append(h.Section(name=f'Myelin[{i}]'))
        Myelin[i].connect(Node[i])
        if i != 0:
            Node[i].connect(Myelin[i - 1])

    for sec in h.allsec():
        sec.insert('extracellular')
        sec.insert('pas')
        sec.insert('xtra')
        sec.Ra = config['Ra']

    for sec in Myelin:
        # nseg is set before any per-segment values are assigned below.
        sec.nseg = config['nseg']
        sec.L = config['L_myelin']*S
        sec.diam = config['diam_myelin']*S
        sec.cm = config['Cm_myelin']
        for seg in sec:
            seg.pas.e = config['epas']
            seg.pas.g = config['gpas_myelin']
            seg.xtra.type = 3
            seg.xtra.order = 1

    for sec in Node:
        for mech in ('K_Pst', 'K_Tst', 'NaTa_t', 'Nap_Et2', 'SKv3_1'):
            sec.insert(mech)
        sec.cm = config['Cm_node']
        sec.diam = config['diam_node']*S
        sec.L = config['L_node']
        for seg in sec:
            # All five channel mechanisms were inserted just above, so their
            # parameters can be set without a has_membrane() check.
            seg.K_Pst.gK_Pstbar = config['gKPst']
            seg.K_Tst.gK_Tstbar = config['gKTst']
            seg.NaTa_t.gNaTa_tbar = config['gNat']
            seg.Nap_Et2.gNap_Et2bar = config['gNap']
            seg.SKv3_1.gSKv3_1bar = config['gSKv3_1']
            seg.xtra.type = 4
            seg.xtra.order = 1

            seg.k_ion.ek = config['ek']
            seg.na_ion.ena = config['ena']
            seg.pas.e = config['epas']
            seg.pas.g = config['gpas_node']

    h.define_shape()
    h.grindaway()

    # xtra's ex must be coupled to e_extracellular in the same segment.
    # hoc syntax:  setpointer ex_xtra(x), e_extracellular(x)
    # (The im_xtra <-> i_membrane coupling is not set up here.)
    for sec in h.allsec():
        for seg in sec:
            h.setpointer(seg._ref_e_extracellular, 'ex', seg.xtra)

    v = {str(sec): h.Vector().record(sec(0.5)._ref_v) for sec in h.allsec()}
    tvec = h.Vector().record(h._ref_t)

    return Node, Myelin, v, tvec

def calcesI(x, y, sigma_e = 2.76e-07, bipolar = 0, phi = 0, dist = 10):
    """Set the extracellular transfer coefficient ``es_xtra`` of segments.

    Point-source electrode(s) in a homogeneous medium of conductivity
    ``sigma_e``.

    bipolar == 0
        One electrode at (x, y). The distance to a segment uses its
        ``x_xtra`` together with ``y`` as the full vertical offset (the
        segment's ``y_xtra`` is not read, i.e. as if every segment were at
        y = 0). Every segment of every section gets
        ``es_xtra = 1e-3 / (4*pi*sigma_e*r)``.
    bipolar == 1
        Two poles ``dist`` apart, centred on (x, y) and oriented at ``phi``
        degrees; pole coordinates are rounded to one decimal. Only sections
        with the xtra mechanism are updated, using
        ``es_xtra = (1e-3/sigma_e) * (1/(4*pi*r_p) - 1/(4*pi*r_n))``.

    Any other ``bipolar`` value leaves ``es_xtra`` unchanged.
    NOTE(review): a segment exactly at an electrode position gives r == 0
    (division by zero).
    """
    if bipolar == 0:
        for section in h.allsec():
            for segment in section:
                r = np.sqrt((x - segment.x_xtra)**2 + y**2)
                segment.es_xtra = 1e-3/(4*np.pi*sigma_e*r)
        return

    if bipolar != 1:
        return

    angle = np.deg2rad(phi)
    half_dx = np.cos(angle)*dist/2
    half_dy = np.sin(angle)*dist/2
    xn, xp = round((x - half_dx), 1), round((x + half_dx), 1)
    yn, yp = round((y - half_dy), 1), round((y + half_dy), 1)

    for section in h.allsec():
        if not h.ismembrane("xtra", sec=section):
            continue
        for segment in section:
            rn = np.sqrt((segment.x_xtra - xn)**2 + (segment.y_xtra - yn)**2)
            rp = np.sqrt((segment.x_xtra - xp)**2 + (segment.y_xtra - yp)**2)
            segment.es_xtra = (1e-3/sigma_e)*(1/(4*np.pi*rp) - 1/(4*np.pi*rn))
            
            


def AmpStep(w, v, Dur = None, Del = None, dt = None, depth = None, loc_find = None, biphasic=False, ipd=None, ratio=None, tau=None, sigma=None, guess=None):
    """Search for the excitation threshold amplitude of one waveform/duration.

    Starting from ``guess`` (default -10), the amplitude is doubled after
    every failed trial until one excites. After that, the search bisects
    between the last failing amplitude (``nex``) and the last exciting one
    (``ex``) until ``|ex - nex| < 10**depth``. The search gives up once
    ``|amp| >= 5000``. A trial counts as an excitation when the Vm recorded
    in ``v['Node[0]']`` exceeds 0.

    Parameters
    ----------
    w : waveform shape identifier passed to ``waveform``.
    v : dict of recorded Vm vectors keyed by section name (as returned by
        ``init_model``); must contain 'Node[0]'.
    Dur : pulse duration (default 0.2).
    Del : delay before the pulse (default 1).
    dt : time step (default Dur/1e3).
    depth : log10 of the bisection tolerance (default -1).
    loc_find : if 1 (default) and a threshold was found, the threshold pulse
        is re-run and the excitation location is recorded.
    biphasic, ipd, ratio : passed to ``waveform``. For biphasic pulses
        ``ipd`` defaults to Dur/10; ``ratio`` defaults to 1.
    tau, sigma : unused; kept for interface compatibility.
    guess : starting amplitude (default -10).

    Returns
    -------
    dict
        'shape', 'data' (stimulus waveform at threshold, or the last tried
        waveform if no threshold was found), 'q' (Simpson integral of the
        waveform over t), 'pwr' (waveform squared), 'e' (integral of pwr),
        't', 'dur', 'th' (threshold amplitude, 0 if none), 'ex' (1 if a
        threshold was found, else 0), 'dt' (h.dt), 'biphasic', plus 'ipd'
        and 'ratio' for biphasic pulses. When the location is searched it
        also contains 'ex_loc', 't_onset' (crossing time minus Del), 'vm_ex'
        and 't_ex' (first-crossing sample indices from ``findExLoc``).
    """
    if Dur is None:
        Dur = 0.2
    if dt is None:
        dt = Dur/1e3
    if Del is None:
        Del = 1
    if depth is None:
        depth = -1
    if loc_find is None:
        loc_find = 1
    if ratio is None:
        ratio = 1

    dat = {}
    t0 = 0
    # Truthiness rather than identity tests, so 0/1/None behave as expected.
    if biphasic:
        if ipd is None:
            ipd = Dur/10
        tstop = max([5, Del+(2+ratio)*Dur+ipd])
    else:
        tstop = max([5, Del+Dur+4])

    # NOTE(review): setStim() sets h.dt = dt before every run, so the doubled
    # step for Dur >= 1 never reaches a simulation; it only survives into
    # dat['dt'] if the search loop below never runs.
    if Dur < 1:
        h.dt = dt
    else:
        h.dt = dt*2

    excite = 0
    ex = 0   # last amplitude that caused excitation (0 = none found yet)
    nex = 0  # last amplitude that failed to cause excitation
    amp = -10 if guess is None else guess

    unit_wave, t = waveform(Dur, Del, 1, w, dt=dt, t0=t0, tstop=tstop, biphasic = biphasic, ipd = ipd, ratio = ratio)
    # Defined up front so the post-processing below also works when the loop
    # body never executes (|guess| >= 5000); previously this was a NameError.
    wave = np.multiply(unit_wave, amp)
    while abs(amp) < 5000:
        print('amp = ' + str(amp))
        wave = np.multiply(unit_wave, amp)
        setStim(wave, t, dt, tstop)  # define stimulation waveform
        h.run()
        recdat = np.array(v['Node[0]'])
        if np.max(recdat) > 0:
            # Excited: remember it and bisect toward the last failing amplitude.
            ex = amp
            amp = (nex + ex)/2
        else:
            # Not excited (a NaN peak also lands here instead of retrying the
            # same amplitude forever): keep doubling until something excites,
            # then bisect.
            nex = amp
            amp = amp*2 if ex == 0 else (nex + ex)/2

        if np.abs(ex - nex) < (10**depth) and ex != 0:
            excite = 1
            break

    if excite == 1:
        # Report charge/energy for the threshold amplitude itself, not for
        # whichever (possibly sub-threshold) amplitude was tried last.
        wave = np.multiply(unit_wave, ex)
    q = simps(wave, x=t)
    pwr = np.multiply(wave, wave)
    e = simps(pwr, x=t)
    if excite == 0:
        print('NO th found for ' + str(w) + ' at duration ' + str(Dur))
    else:
        print('th found for ' + str(w) + ' at duration ' + str(Dur))

    dat['shape'] = w
    dat['data'] = wave
    dat['q'] = q
    dat['pwr'] = pwr
    dat['e'] = e
    dat['t'] = t
    dat['dur'] = Dur
    dat['th'] = ex
    dat['ex'] = excite
    dat['dt'] = h.dt
    dat['biphasic'] = biphasic
    if biphasic:
        dat['ipd'] = ipd
        dat['ratio'] = ratio

    if loc_find == 1 and excite == 1:
        setStim(unit_wave*ex, t, dt, tstop)
        h.run()
        vm, t_ex_lst, loc, t_ex = findExLoc(v)
        dat['ex_loc'] = loc
        # findExLoc returns sample indices. The Vm vectors are recorded on
        # every time step starting at t = 0, so index * h.dt converts the
        # index to a time (to within one step) before subtracting the delay.
        dat['t_onset'] = t_ex*h.dt - Del
        dat['vm_ex'] = {name: vm[name] for name in loc}
        dat['t_ex'] = t_ex_lst
    return dat

def SDdata(Dlist, w, rec, Del = None, dt = None, depth = None, x=200, y=50, bipolar=0, dist=200, theta=90, phi=0, loc_find=None, tau = None, sigma= None, biphasic=False, ipd=None, ratio=None):
    """Collect threshold data over a list of pulse durations (strength-duration).

    The extracellular field is set once via ``calcesI(x, y, ...)``. Then
    ``AmpStep`` is run for each duration in ``Dlist``, each rounded to two
    decimals and used as the result key. Each found threshold seeds the
    next duration's search as its starting amplitude.

    As soon as one duration yields no threshold, the search stops and an
    empty dict is returned; results gathered for earlier durations are
    discarded as well. ``theta`` is accepted but unused.

    Returns
    -------
    dict
        Rounded duration -> dict returned by ``AmpStep``, or ``{}``.
    """
    calcesI(x, y, bipolar = bipolar, dist = dist, phi = phi)

    results = {}
    start_amp = None
    for pulse_width in Dlist:
        pulse_width = round(pulse_width, 2)
        outcome = AmpStep(w, rec, pulse_width, Del, dt, depth, loc_find, biphasic, ipd, ratio, tau, sigma, start_amp)
        results[pulse_width] = outcome
        if outcome['ex'] == 0:
            # No threshold at this width: skip the remaining (longer) widths
            # to save computation time.
            return {}
        start_amp = outcome['th']

    return results

def setStim(data, t, dt, tstop):
    """Load a sampled stimulus into NEURON and attach it to ``stim_xtra``.

    The module-level h.Vectors ``stim_amp`` and ``stim_time`` are created on
    first use and reused (resized and refilled) on later calls. ``h.dt`` and
    ``h.tstop`` are set for the following run.

    Parameters
    ----------
    data : sequence of stimulus amplitudes, one per sample.
    t : sample times; needs at least as many entries as ``data``.
    dt, tstop : written to ``h.dt`` and ``h.tstop``.
    """
    global stim_amp, stim_time
    if 'stim_amp' not in globals():
        stim_amp = h.Vector()
    if 'stim_time' not in globals():
        stim_time = h.Vector()

    h.dt = dt
    h.tstop = tstop

    n_amp = np.size(data)
    stim_amp.resize(n_amp)
    stim_amp.fill(0)
    stim_time.resize(np.size(t))
    stim_time.fill(0)
    # Copy sample by sample; any stim_time entries beyond len(data) stay 0.
    for k in range(n_amp):
        stim_amp[k] = data[k]
        stim_time[k] = t[k]

    attach_stim()

def attach_stim():
    """Play the global ``stim_amp`` vector into xtra's ``stim_xtra``.

    Uses the global ``stim_time`` vector as the sample times; both are
    created by ``setStim``. The third ``play`` argument (1) selects NEURON's
    continuous playback mode.
    """
    target = h._ref_stim_xtra
    stim_amp.play(target, stim_time, 1)

def findExLoc(rec):
    """Locate where excitation starts, from already-recorded membrane potentials.

    For each recorded section, this finds the first sample at which the
    sign of Vm changes (the first 0 mV crossing). The section(s) whose first
    crossing comes earliest are reported as the excitation location.
    Intended to be called right after a run with suprathreshold stimulation
    parameters.

    Parameters
    ----------
    rec : dict
        Maps section names to recorded vectors; each value must provide
        ``to_python()`` (e.g. NEURON ``h.Vector``).

    Returns
    -------
    v : dict
        Section name -> recorded Vm as a Python list.
    t : list
        First-crossing sample index for every section that crossed, in
        ``rec`` order. Sections that never cross are omitted, so positions
        do not line up with ``rec``.
    loc : list of str
        Names of the section(s) with the earliest crossing.
    t_ex : int
        Earliest crossing as a sample index: index i means the sign changes
        between samples i and i + 1. Multiply by the recording step to get
        a time.

    Raises
    ------
    ValueError
        If no recorded section changes sign.
    """
    v = {}
    first_cross = {}
    for name, vec in rec.items():
        v[name] = vec.to_python()
        # Compute the crossing indices once per vector.
        crossings = np.flatnonzero(np.diff(np.sign(v[name])))
        if crossings.size:
            first_cross[name] = crossings[0]

    if not first_cross:
        raise ValueError('findExLoc: no recorded section crossed 0 mV')

    t = list(first_cross.values())
    t_ex = min(t)
    loc = [name for name, idx in first_cross.items() if idx == t_ex]

    return v, t, loc, t_ex

def setConfig(Layer):
    """Return the model parameter dict for cortical layer ``Layer``.

    Supported layers are 1, 23, 4, 5 and 6; any other value returns None.
    All layers share the general settings, myelin nseg/capacitance/leak and
    node capacitance/length. They differ in myelin and node geometry, node
    leak, the listed node mechanisms and the channel densities. A fresh dict
    (with a fresh 'mechanisms' list) is built on every call, so callers may
    modify the result freely.
    """
    def build(L_myelin, diam_myelin, mechanisms, diam_node, gpas_node,
              gKPst, gKTst, gNat, gNap, gSKv3_1):
        # Keys are produced in a fixed order: general, myelin, nodes.
        return {
            # General
            'vinit': -80,
            'celsius': 37,
            'nodes': 101,
            'Ra': 100,
            'epas': -75,
            'ek': -85,
            'ena': 50,
            # Myelin
            'nseg': 5,
            'L_myelin': L_myelin,
            'diam_myelin': diam_myelin,
            'Cm_myelin': 0.02,
            'gpas_myelin': 8.89e-07,
            # Nodes
            'mechanisms': list(mechanisms),
            'Cm_node': 1,
            'diam_node': diam_node,
            'L_node': 1,
            'gpas_node': gpas_node,
            'gKPst': gKPst,
            'gKTst': gKTst,
            'gNat': gNat,
            'gNap': gNap,
            'gSKv3_1': gSKv3_1,
        }

    all_mechs = ('K_Pst', 'K_Tst', 'NaTa_t', 'Nap_Et2', 'SKv3_1')
    per_layer = {
        1: dict(L_myelin=40, diam_myelin=0.72, mechanisms=all_mechs,
                diam_node=0.478, gpas_node=8e-6, gKPst=1.69e-03,
                gKTst=4.21e-02, gNat=8, gNap=1e-06, gSKv3_1=3.87e-01),
        23: dict(L_myelin=55, diam_myelin=1.051, mechanisms=all_mechs,
                 diam_node=0.766, gpas_node=3e-5, gKPst=9.59e-01,
                 gKTst=1.04e-3, gNat=6.86, gNap=9.80e-03, gSKv3_1=9.50e-02),
        4: dict(L_myelin=49, diam_myelin=1.028,
                mechanisms=('K_Pst', 'NaTa_t', 'SKv3_1'),
                diam_node=0.751, gpas_node=6.3e-5, gKPst=6.85e-02,
                gKTst=0, gNat=7.99, gNap=0, gSKv3_1=5.18e-01),
        5: dict(L_myelin=59, diam_myelin=1.247, mechanisms=all_mechs,
                diam_node=0.93, gpas_node=3e-5, gKPst=0.973538,
                gKTst=0.089259, gNat=6.275936, gNap=0.006827,
                gSKv3_1=1.021945),
        6: dict(L_myelin=37, diam_myelin=0.747, mechanisms=all_mechs,
                diam_node=0.492, gpas_node=3e-5, gKPst=9.57e-01,
                gKTst=2.95e-02, gNat=6.58, gNap=6.71e-04, gSKv3_1=1.94),
    }

    overrides = per_layer.get(Layer)
    if overrides is None:
        return None
    return build(**overrides)

def main(Layer, xloc_e, biphasic, ydist, Scale):
    """Compute strength-duration data for all waveform shapes and pickle it.

    Parameters
    ----------
    Layer : key accepted by ``setConfig`` (1, 23, 4, 5 or 6).
    xloc_e : 0 places the electrode at x = 0 (first node); any other value
        places it at the x position of the centre node.
    biphasic : passed to ``SDdata``; also part of the output file name.
    ydist : electrode y distance passed to ``SDdata``.
    Scale : geometry scale factor passed to ``init_model``.

    Results are written to
    ``datafiles/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{Scale}.pkl``
    (dots in Scale replaced by underscores). The directory is created if
    needed. An invalid layer prints an error and exits with status 1.
    """
    import os

    # Pulse durations: step 0.01 below 0.5, step 0.1 up to 1, then 1, 2 and 5.
    dlist = np.append(np.arange(0.01, 0.5, 0.01), np.append(np.arange(0.5, 1, 0.1), np.array([1, 2, 5])))

    config = setConfig(Layer)
    if config is None:
        print('Error: Invalid layer')
        sys.exit(1)

    Node, _myelin, v, _tvec = init_model(config, Scale)
    wlist = ['r', 'g', 's', 't', 'ru', 'rd']

    if xloc_e == 0:  # 0 = electrode at the beginning (first node)
        x0 = 0
    else:  # 1 = electrode above the centre node
        # Use the centre node's own x position, as mainw() does; halving it
        # placed the electrode near a quarter of the axon instead of its centre.
        x0 = Node[int((config['nodes'] - 1)/2)].x_xtra

    dt = 0.0001
    data = {}

    print('Start')
    for w in wlist:
        data[w] = SDdata(dlist, w, v, dt = dt, depth = -2, x=x0, y=ydist, biphasic = biphasic)

    os.makedirs('datafiles', exist_ok=True)
    fname = f'datafiles/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}.pkl'
    # Context manager guarantees the pickle is flushed and the file closed.
    with open(fname, 'wb') as file:
        pickle.dump(data, file)
    print('Done')

def mainw(Layer, w, xloc_e, biphasic, ydist, Scale):
    """Compute threshold data for one waveform shape at pulse duration 1.

    Parameters are as for ``main``, plus ``w``, the waveform shape
    identifier. Nothing is written to disk.

    Returns
    -------
    dict
        ``{w: SDdata(...)}``.

    An invalid layer prints an error and exits with status 1.
    """
    dlist = [1]

    config = setConfig(Layer)
    if config is None:
        print('Error: Invalid layer')
        sys.exit(1)

    Node, _myelin, v, _tvec = init_model(config, Scale)

    if xloc_e == 0:  # 0 = electrode at the beginning (first node)
        x0 = 0
    else:  # 1 = electrode above the centre node
        x0 = Node[int((config['nodes'] - 1)/2)].x_xtra

    dt = 0.0001
    data = {w: SDdata(dlist, w, v, dt = dt, depth = -2, x=x0, y=ydist, biphasic = biphasic)}

    print('Done')
    return data
        
# if __name__ == '__main__':
    #get sys args:
    # Layer = int(sys.argv[1])
    # xloc_e = int(sys.argv[2])
    # biphasic = bool(int(sys.argv[3]))
    # ydist = int(sys.argv[4]) 
    # Scale = float(sys.argv[5])
    # bipolar = 0 #int(sys.argv[2])
    
    

# attach_stim()
# setStim(1, 1, 0, 'r', 0.001, 0, 50, 0, 0, False, 0, 0)
# h.run()
# for sec in h.allsec():
#     plt.plot(tvec, v[sec])
