# -*- coding: utf-8 -*-
"""
Created on Thu Jul 15 11:20:14 2021

@author: francescvarkev
"""
from neuron import h
import numpy as np
from matplotlib import pyplot as plt
from WaveformModule import waveform
from scipy.integrate import simps
from neuron.units import ms, mV
import os, sys
from calcVe import calcesI
from model import findExLoc
import collections
from statistics import mean
from time import time

# dat = {}

#def setStim(Dur, Del, amp, w, dt, t0, tstop, tau, sigma, biphasic, ipd, ratio):
#      data, t = waveform(Dur, Del, amp, w, dt=dt, t0=t0, tstop=tstop, tau=tau, sigma=sigma, biphasic=biphasic, ipd=ipd, ratio=ratio)
#      h.stim_amp.resize(np.size(data))
#      h.stim_amp.fill(0)
#      h.stim_time.resize(np.size(t))
#      h.stim_time.fill(0)
#      for i in range(np.size(data)):
#          h.stim_amp[i] = data[i]
#          h.stim_time[i] = t[i]
#      return data, t

def setStim(data, t, dt, tstop):
    """Load a stimulus waveform into the HOC play vectors and set the run clock.

    Assigns ``h.dt`` and ``h.tstop``, then resizes and zero-fills
    ``h.stim_amp`` (to ``np.size(data)``) and ``h.stim_time`` (to
    ``np.size(t)``) before copying the samples in.

    Parameters
    ----------
    data : array-like
        Stimulus amplitude samples copied into ``h.stim_amp``.
    t : array-like
        Sample times; the first ``np.size(data)`` entries are copied into
        ``h.stim_time``, so it must be at least that long.
    dt : float
        Integration step assigned to ``h.dt``.
    tstop : float
        Simulation end time assigned to ``h.tstop``.
    """
    h.dt = dt
    h.tstop = tstop

    amp_vec = h.stim_amp
    time_vec = h.stim_time
    n_samples = np.size(data)

    # Size both vectors and clear any samples left over from a previous call.
    amp_vec.resize(n_samples)
    amp_vec.fill(0)
    time_vec.resize(np.size(t))
    time_vec.fill(0)

    for idx in range(n_samples):
        amp_vec[idx] = data[idx]
        time_vec[idx] = t[idx]
 
def runSim(v_soma):
    """Run the NEURON simulation and return the recorded trace as an array.

    Parameters
    ----------
    v_soma : recording vector
        Vector filled during ``h.run()`` (presumably an ``h.Vector`` recording
        somatic voltage — confirm at the call site).

    Returns
    -------
    numpy.ndarray
        A copy of ``v_soma`` taken after the run finishes.
    """
    h.run()
    return np.array(v_soma)

# def AmpStep(w, Dur = 0.2, Del = 1, dt = 0.01, depth = -2):
#     soma = h.cell.soma[0]
#     v = h.Vector().record(soma(0.5)._ref_v) # create voltage record vector at soma
#     t_vec = h.Vector().record(h._ref_t) # create time record vector
#     dat = {}
    
#     t = np.linspace(0, 5, 50000)
#     with HiddenPrints():
#         h.dt = dt
#     excite = 0
#     amp = 0
#     for i in np.logspace(1, depth, (2-depth)):
#         # print(i)
#         while amp<100:
#             wave = setStim(t, Del, Dur, amp, w) # define stimulation waveform 
#             h.run()
#             recdat = np.array(v)
#             if np.max(recdat) > 0:
#                 if i != 10**depth:
#                     amp = amp + i
#                 else:
#                     excite = 1
#                 break
#             amp = amp - i
        
#     q = simps(wave, t)
#     pwr = np.multiply(wave, wave)
#     e = simps(pwr, t)
#     if excite == 1:
#         print('th found for ' + str(w) + ' at duration ' + str(Dur))
#     # plt.plot(t_vec, v)
#     dat['shape'] = w
#     dat['data'] = wave
#     dat['q'] = q
#     dat['pwr'] = pwr
#     dat['e'] = e
#     dat['t'] = t
#     dat['dur'] = Dur
#     dat['th'] = round(amp, -depth)
#     dat['ex'] = excite
    
#     return dat

def _search_threshold(unit_wave, t, v, amp, sim_dt, tstop):
    """Doubling-then-bisection search for the threshold amplitude of ``unit_wave``.

    Each trial scales ``unit_wave`` by the current amplitude, loads it with
    ``setStim`` and runs the simulation. A trial counts as excitation when the
    maximum of the recorded vector ``v`` is above 0. Until the first excitation
    the amplitude is doubled; afterwards it is bisected between the last
    excitatory (``ex``) and last non-excitatory (``nex``) amplitude. The search
    stops when ``|ex - nex| < 0.01 * |ex|`` or once ``|amp|`` reaches 5000.

    Returns
    -------
    tuple
        ``(ex, excite, wave)``. ``ex`` is the last excitatory amplitude (0 if
        none). ``excite`` is 1 if the search converged and 0 otherwise. ``wave``
        is the last waveform applied, or ``unit_wave * amp`` if no trial ran.
    """
    ex = 0   # last amplitude causing excitation
    nex = 0  # last amplitude that failed to cause excitation
    excite = 0
    # Defined up front so callers always get a waveform, even if abs(amp) >= 5000.
    wave = np.multiply(unit_wave, amp)
    while abs(amp) < 5000:
        print('amp = ' + str(amp))
        wave = np.multiply(unit_wave, amp)
        setStim(wave, t, sim_dt, tstop)
        h.run()
        recdat = np.array(v)
        if np.max(recdat) > 0:
            ex = amp
        else:
            nex = amp
        # No excitation seen yet: grow the amplitude; otherwise bisect.
        amp = amp * 2 if ex == 0 else (nex + ex) / 2

        if ex != 0 and np.abs(ex - nex) < (0.01 * np.abs(ex)):
            excite = 1
            break
    return ex, excite, wave


def AmpStep(w, rec, Dur = None, Del = None, dt = None, depth = None, loc_find = None, biphasic=False, ipd=None, ratio=None, tau=None, sigma=None, guess=None):
    """Find the threshold amplitude for stimulus shape ``w`` and summarise it.

    The waveform is generated once at unit amplitude by ``waveform``. It is then
    scaled during a doubling/bisection search (see ``_search_threshold``).

    Parameters
    ----------
    w : stimulus shape identifier forwarded to ``waveform``.
    rec : recording dict. ``rec['voltage'][str(h.cell.soma[0])]`` is used as the
        somatic voltage vector. It is also passed to ``findExLoc``.
    Dur : pulse duration (default 0.2; units presumably ms).
    Del : delay before the pulse (default 1).
    dt : waveform sample step (default ``Dur / 1e3``). The simulation runs with
        ``dt`` when ``Dur < 1`` and with ``2 * dt`` otherwise.
    depth : unused; kept for backward compatibility.
    loc_find : if 1 (default) and a threshold is found, the simulation is rerun
        at threshold and excitation location/timing is recorded via ``findExLoc``.
    biphasic, ipd, ratio : biphasic options forwarded to ``waveform``. ``ipd``
        defaults to ``Dur / 10`` and ``ratio`` defaults to 1.
    tau, sigma : shape parameters forwarded to ``waveform`` when not None.
        When None, ``waveform``'s own defaults apply.
    guess : starting amplitude for the search; its sign sets the polarity.
        None or 0 falls back to -50.

    Returns
    -------
    dict
        Always contains these keys:
        - ``shape``, ``dur``, ``t``.
        - ``data``: the threshold waveform if found, else the last trial.
        - ``q``: Simpson integral of ``data`` over ``t``.
        - ``pwr``: ``data**2``.
        - ``e``: Simpson integral of ``pwr``.
        - ``th``: the threshold amplitude (0 if none).
        - ``ex``: 1/0 flag for whether a threshold was found.
        - ``dt``: the simulation step.
        - ``biphasic``.

        Conditional keys:
        - ``ipd`` and ``ratio`` when biphasic.
        - ``ex_loc``, ``t_onset``, ``vm_ex`` and ``t_ex`` when location finding ran.
    """
    if Dur is None:
        Dur = 0.2
    if dt is None:
        dt = Dur/1e3
    if Del is None:
        Del = 1
    if depth is None:
        depth = -1  # not used by the current search; kept for API compatibility
    if loc_find is None:
        loc_find = 1
    if ratio is None:
        ratio = 1

    soma = h.cell.soma[0]
    v = rec['voltage'][str(soma)]  # voltage record vector at the soma

    t0 = 0
    # Truthiness (not `is False`) so 0/1 flags are treated like booleans.
    if biphasic:
        if ipd is None:
            ipd = Dur/10
        tstop = max([5, Del+(2+ratio)*Dur+ipd])
    else:
        tstop = max([5, Del+Dur+4])

    # Long pulses are simulated with a coarser step. This value must be handed
    # to setStim, because setStim overwrites h.dt on every call.
    sim_dt = dt if Dur < 1 else dt*2
    h.dt = sim_dt

    # A zero guess has no polarity and would make the doubling phase loop forever.
    amp = guess if guess else -50

    # Only forward tau/sigma when given, so waveform keeps its own defaults.
    shape_kwargs = {key: val for key, val in (('tau', tau), ('sigma', sigma)) if val is not None}
    unit_wave, t = waveform(Dur, Del, 1, w, dt=dt, t0=t0, tstop=tstop, biphasic=biphasic, ipd=ipd, ratio=ratio, **shape_kwargs)

    ex, excite, wave = _search_threshold(unit_wave, t, v, amp, sim_dt, tstop)
    if excite == 1:
        # Report the waveform at threshold. The last bisection trial may have
        # been the sub-threshold amplitude.
        wave = np.multiply(unit_wave, ex)

    q = simps(wave, t)
    pwr = np.multiply(wave, wave)
    e = simps(pwr, t)
    if excite == 0:
        print('NO th found for ' + str(w) + ' at duration ' + str(Dur))
    else:
        print('th found for ' + str(w) + ' at duration ' + str(Dur))

    dat = {}
    dat['shape'] = w
    dat['data'] = wave
    dat['q'] = q
    dat['pwr'] = pwr
    dat['e'] = e
    dat['t'] = t
    dat['dur'] = Dur
    dat['th'] = ex
    dat['ex'] = excite
    dat['dt'] = h.dt
    dat['biphasic'] = biphasic
    if biphasic:
        dat['ipd'] = ipd
        dat['ratio'] = ratio

    if loc_find == 1 and excite == 1:
        # Rerun exactly at threshold to locate where/when excitation starts.
        setStim(wave, t, sim_dt, tstop)
        h.run()
        vm, t_ex_lst, loc, t_ex = findExLoc(rec)
        dat['ex_loc'] = loc
        dat['t_onset'] = t_ex - Del
        dat['vm_ex'] = {site: vm[site] for site in loc}
        dat['t_ex'] = t_ex_lst
    return dat
#    while abs(amp)<5000:
#        print('amp = ' + str(amp))
#        wave, t = setStim(Dur, Del, amp, w, dt, t0, tstop, tau, sigma, biphasic, ipd, ratio) # define stimulation waveform 
#        h.run()
#        recdat = np.array(v)
#        if np.max(recdat) > 0:
#            ex = amp
#            amp = (nex + ex)/2
#        elif np.max(recdat) <= 0:
#            if ex == 0:
#                nex = amp
#                amp = amp*2
#            else:
#                nex = amp
#                amp = (nex + ex)/2
#
#        if np.abs(ex - nex) < (0.01*np.abs(ex)) and ex != 0:
#            # print(amp)
#            # print(np.max(recdat))
#            excite = 1
#            break  
#        
#    q = simps(wave, t)
#    pwr = np.multiply(wave, wave)
#    e = simps(pwr, t)
#    if excite == 0:
#        print('NO th found for ' + str(w) + ' at duration ' + str(Dur))
#    else:
#        print('th found for ' + str(w) + ' at duration ' + str(Dur))
#    # plt.plot(t_vec, v)
#    dat['shape'] = w
#    dat['data'] = wave
#    dat['q'] = q
#    dat['pwr'] = pwr
#    dat['e'] = e
#    dat['t'] = t
#    dat['dur'] = Dur
#    dat['th'] = ex
#    dat['ex'] = excite
#    dat['dt'] = h.dt
#    dat['biphasic'] = biphasic
#    if biphasic is True:
#        dat['ipd'] = ipd
#        dat['ratio'] = ratio
#    
#    if loc_find == 1 and excite == 1:
#        _ = setStim(Dur, Del, ex, w, dt, t0, tstop, tau, sigma, biphasic, ipd, ratio)
#        h.run()
#        v, t_ex_lst, loc, t_ex = findExLoc(rec)
#        dat['ex_loc'] = loc
#        dat['t_onset'] = t_ex - Del
#        dat['vm_ex'] = {}
#        for l in loc:
#            dat['vm_ex'][l] = v[l]
#        dat['t_ex'] = t_ex_lst
#    return dat

# def FixedAmp(w, rec, amp, Dur = 0.2, Del = 1, dt = 0.0125, depth = -2, loc_find = 1, tau=0.4, sigma=0.1, guess=None):
#     soma = h.cell.soma[0]
#     v = rec['voltage'][str(soma)] # create voltage record vector at soma
#     dat = {}
#     # global nex, ex, amp,t
#     a = 0
#     b = 5
#     n = int((b-a)/dt)
#     t = np.linspace(a, b, n)
#     h.dt = dt
#     excite = 0
#     wave = setStim(t, Del, Dur, amp, w, tau, sigma) # define stimulation waveform 
#     h.run()
#     recdat = np.array(v)
#     if np.max(recdat) > 0:
#         excite = 1
        
#     q = simps(wave, t)
#     pwr = np.multiply(wave, wave)
#     e = simps(pwr, t)
#     if excite == 0:
#         print('NO th found for ' + str(w) + ' at duration ' + str(Dur))
#     else:
#         print('th found for ' + str(w) + ' at duration ' + str(Dur))
#     # plt.plot(t_vec, v)
#     dat['shape'] = w
#     dat['data'] = wave
#     dat['q'] = q
#     dat['pwr'] = pwr
#     dat['e'] = e
#     dat['t'] = t
#     dat['dur'] = Dur
#     dat['th'] = amp
#     dat['ex'] = excite
    
#     if loc_find == 1 and excite == 1:
#         _ = setStim(t, Del, Dur, amp, w, tau, sigma)
#         v, t_ex_lst, loc, t_ex = findExLoc(rec)
#         dat['ex_loc'] = loc
#         dat['t_onset'] = t_ex - Del
#         dat['vm_ex'] = {}
#         for l in loc:
#             dat['vm_ex'][l] = v[l]
#         dat['t_ex'] = t_ex_lst
#     return dat

def SDdata(Dlist, w, rec, Del = None, dt = None, depth = None, x=200, y=-50, z=0, bipolar=0, dist=200, theta=90, phi=0, loc_find=None, tau = None, sigma= None, biphasic=False, ipd=None, ratio=None):
    """Build strength-duration data: run one ``AmpStep`` threshold search per pulse width.

    ``calcesI`` is called once with the electrode geometry (``x``, ``y``, ``z``,
    ``bipolar``, ``dist``, ``theta``, ``phi``). Each duration in ``Dlist`` is then
    rounded to 2 decimals and passed to ``AmpStep``. Each search is seeded with
    the previous threshold as ``guess``. After a duration with no threshold,
    the next search starts at -4999, so a non-excitatory first trial doubles
    past the 5000 amplitude cap and that search ends immediately.

    Returns
    -------
    dict
        Maps each rounded duration to the result dict returned by ``AmpStep``.
    """
    calcesI(x, y, z, bipolar=bipolar, dist=dist, theta=theta, phi=phi)

    results = {}
    next_guess = None
    for raw_dur in Dlist:
        dur = round(raw_dur, 2)
        entry = AmpStep(w, rec, Dur=dur, Del=Del, dt=dt, depth=depth,
                        loc_find=loc_find, biphasic=biphasic, ipd=ipd,
                        ratio=ratio, tau=tau, sigma=sigma, guess=next_guess)
        results[dur] = entry
        next_guess = -4999 if entry['ex'] == 0 else entry['th']

    return results


# if __name__ == '__main__':
#     from model import createRecordingDict, initCell
#     initCell(20)
#     rec = createRecordingDict()
#     tic = time()
#     data = {}
#     d = []
#     e = []
#     s = []
#     Dlist = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.25, 1.5, 2.5, 5]
#     data = SDdata(Dlist, 'g', rec, x=-700, y=1100, z=0, dt = 0.005, interp_th=1)
    
#     for dur in data:
#         d.append(dur)
#         e.append(data[dur]['e'])
#         s.append(data[dur]['th'])
#     data['e'] = e
#     data['d'] = d
#     data['s'] = s
#     toc = time() - tic
#     print(toc)
    
    # tic = time()
    # data_lin = {}
    # d = []
    # e = []
    # s = []
    # Dlist = np.arange(0.05, np.round(1.55, 5), 0.05)
    # data_lin = SDdata(Dlist, 'g', rec, x=-700, y=1100, z=0, dt = 0.005, interp_th=1.05)
    
    # for dur in data_lin:
    #     d.append(dur)
    #     e.append(data_lin[dur]['e'])
    #     s.append(data_lin[dur]['th'])
    # data_lin['e'] = e
    # data_lin['d'] = d
    # data_lin['s'] = s
    # toc = time() - tic
    # print(toc)
    
    # plt.plot(data_lin['d'], np.abs(data_lin['s']))
    # plt.plot(data['d'], np.abs(data['s']))
    
    # plt.figure()
    
    # plt.plot(data_lin['d'], np.abs(data_lin['e']))
    # plt.plot(data['d'], np.abs(data['e']))
