# -*- coding: utf-8 -*-
"""
Created on Wed Jan 12 15:04:29 2022

@author: francescvarkev
"""
import pickle
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import simps
import itertools
from WaveformModule import waveform
from bisect import bisect_left
import matplotlib as mpl


# Plot styling. Matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8-*' and 3.8 removed the old names, so pick whichever this
# installation provides.
_WHITEGRID_STYLE = ('seaborn-v0_8-whitegrid'
                    if 'seaborn-v0_8-whitegrid' in plt.style.available
                    else 'seaborn-whitegrid')
mpl.style.use(_WHITEGRID_STYLE)
plt.rcParams['legend.frameon'] = True
plt.rcParams['legend.edgecolor'] = 'white'
plt.rcParams['font.size'] = 14
plt.rcParams['axes.labelsize'] = 16
plt.rcParams['legend.fontsize'] = 14
savefigs = False  # set True to write the figures into figures/
# Marker styles handed out in turn (via next(marker)) to every plotted curve.
marker = itertools.cycle(('o', "s", "|", "x", "^", "v"))

# Simulation configuration. These values only select the data file
# datafiles/data_{Layer}_{xloc_e}_{biphasic}_{ydist}_{Scale}.pkl (with '.' in
# Scale replaced by '_') and the matching figure names.
# NOTE(review): the physical meaning of each parameter is defined by the
# simulation code that produced the pickles — confirm there.
Layer = 5
xloc_e = 1
biphasic = 0  # 1 -> threshold charges are recomputed from the stored waveforms
ydist = 100
Scale = 1

#%%
# Load the simulation results for the configuration selected above.
# From the lookups below, `data` maps waveform key -> pulse width -> dict with
# at least 'th', 'e' and 'q' (and 'data' / 't' when biphasic == 1).
# NOTE(review): pickle.load can execute arbitrary code; only load files
# produced by this project's own simulations.
filename = f'datafiles/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}.pkl'
with open(filename, 'rb') as f:
    data = pickle.load(f)

# Display names for the legends. Legends are passed positionally, so this
# order must match the key order of `data` (i.e. of `wlist`).
legendlist = ['Rectangular', 'Gaussian', 'Half-Sine', 'Triangular', 'Ramp-Up', 'Ramp-Down']
wlist = list(data)       # waveform keys
dur = list(data['r'])    # pulse widths, taken from the 'r' waveform for all shapes

if biphasic == 1:
    # Recompute the charge as the negated time integral of |I|, so that both
    # phases of the stored biphasic waveform contribute.
    for w in wlist:
        for d in dur:
            wave = data[w][d]['data']
            t = data[w][d]['t']
            data[w][d]['q'] = -1*simps(np.abs(wave), t)

th = {}      # threshold current per waveform, one entry per pulse width
q = {}       # threshold charge per waveform
e = {}       # energy plotted as E_adaptive
econst = {}  # energy plotted as E_constant: threshold current * charge
for w in wlist:
    th[w] = [data[w][d]['th'] for d in dur]
    e[w] = [data[w][d]['e'] for d in dur]
    q[w] = [data[w][d]['q'] for d in dur]
    econst[w] = [i_th*q_th for i_th, q_th in zip(th[w], q[w])]

#%%    
# Strength-duration (S/D) curves: |I_th| against pulse width for every waveform.
plt.figure()
for w in wlist:
    plt.plot(dur, np.abs(th[w]), label=w, marker=next(marker), fillstyle='none', markevery=10)
plt.legend(legendlist)
# plt.title(f'S/D {Layer} {xloc_e} {ydist}')
# y-range is anchored on the smallest rectangular-pulse threshold.
ith_rect_min = min(np.abs(th['r']))
plt.ylim([0.9*ith_rect_min, 5*ith_rect_min])
plt.xlim([0, 1])
plt.xlabel('Pulse Width [ms]')
# Raw string: '\m' is an invalid escape (SyntaxWarning on Python 3.12+).
plt.ylabel(r'$|I_{th}|$ [$\mu$A]')
plt.tight_layout()
if savefigs:
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_SD.eps')
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_SD.png')

# Charge-duration (Q/D) curves: |Q_th| against pulse width for every waveform.
plt.figure()
for w in wlist:
    plt.plot(
        dur,
        np.abs(q[w]),
        label=w,
        marker=next(marker),
        fillstyle='none',
        markevery=10,
    )
plt.legend(legendlist)
plt.xlim([0, 1])
# Upper y-limit: 1.2x the rectangular-pulse charge at pulse-width index 54.
# NOTE(review): presumably index 54 is the widest pulse shown in the 0-1 ms
# window — confirm against the duration grid stored in the data file.
plt.ylim([0, np.abs(q['r'][54])*1.2])
plt.xlabel('Pulse Width [ms]')
plt.ylabel('$|Q_{th}|$ [nC]')
plt.tight_layout()
if savefigs is True:
    for ext in ('eps', 'png'):
        plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_QD.{ext}')

# Energy-duration curves for the adaptive-supply energy `e`, then each
# waveform's minimum energy relative to the rectangular pulse, in percent.
plt.figure()
for w in wlist:
    plt.plot(dur, e[w], label=w, marker=next(marker), fillstyle='none', markevery=10)
plt.legend(legendlist)
# plt.title(f'E/D {Layer} {xloc_e} {ydist}')
minE = min(e['r'])
plt.ylim([0.8*minE, 2*minE])
plt.xlim([0, 1])
plt.xlabel('Pulse Width [ms]')
# Raw string: '\O' is an invalid escape (SyntaxWarning on Python 3.12+).
plt.ylabel(r'$E_{adaptive}$ [pJ/k$\Omega$]')
plt.tight_layout()
if savefigs:
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_EDadaptive.eps')
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_EDadaptive.png')

print('Eadaptive:')
for w in wlist:
    print(f'{w}: {round(((min(e[w])/min(e["r"]) -1) * 100), 2)}')
    
# Energy-duration curves for the constant-supply energy econst = I_th * Q_th.
plt.figure()
for w in wlist:
    plt.plot(dur, econst[w], label=w, marker=next(marker), fillstyle='none', markevery=10)
plt.legend(legendlist)
# plt.title(f'Econst/D {Layer} {xloc_e} {ydist}')
minEc = min(econst['r'])
plt.ylim([0.8*minEc, 2*minEc])
plt.xlim([0, 1])
plt.xlabel('Pulse Width [ms]')
# Raw string: '\O' is an invalid escape (SyntaxWarning on Python 3.12+).
plt.ylabel(r'$E_{constant}$ [pJ/k$\Omega$]')
plt.tight_layout()
if savefigs:
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_EDconstant.eps')
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_EDconstant.png')

# Threshold current against threshold charge for every waveform. The last two
# pulse widths are left out (reason not recorded here).
plt.figure()
for w in wlist:
    plt.plot(np.abs(q[w][:-2]), np.abs(th[w][:-2]), label=w, marker=next(marker), fillstyle='none', markevery=10)
plt.legend(legendlist)
ith_rect_min = min(np.abs(th['r']))
plt.ylim([0.9*ith_rect_min, 7*ith_rect_min])
plt.xlabel('$|Q_{th}|$ [nC]')
# Raw string: '\m' is an invalid escape (SyntaxWarning on Python 3.12+).
plt.ylabel(r'$|I_{th}|$ [$\mu$A]')
plt.tight_layout()
if savefigs:
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_IQ.eps')
    plt.savefig(f'figures/data_{Layer}_{xloc_e}_{int(biphasic)}_{ydist}_{str(Scale).replace(".", "_")}_IQ.png')

# Minimum constant-supply energy of each waveform relative to the rectangular
# pulse, in percent.
print('Econst:')
for w in wlist:
    print(f'{w}: {round(((min(econst[w])/min(econst["r"]) - 1 ) * 100 ), 2)}')
    
#%%
# Schematics comparing supply strategies for a single current pulse.
# `waveform` comes from the project's WaveformModule; its positional arguments
# are assumed here to be (pulse width, delay, amplitude, shape key) — TODO confirm.
plt.figure()  # fresh figure: without it this pulse was drawn over the I-Q plot
I, t = waveform(1, 0, 1, 's', dt=0.001, t0=-0.2, tstop=1.2)
plt.plot(t, I, linewidth=4)
plt.xticks([0, 1], ['0', 'PW'], size=18)
plt.yticks([0, 1], size=18)
plt.xlim([-0.2, 1.2])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.ylim([0, 1.2])
plt.ylabel('Current [A]', size=20)
if savefigs:
    plt.savefig('figures/Cpulse.eps')
# Voltage trace used as the constant supply (multiplied with I in Pconst below).
V, t = waveform(2, -0.5, 1, 'r', dt=0.001, t0=-0.2, tstop=1.2)

# Four equally spaced supply levels: 0, 1/3, 2/3, 1. Each current sample is
# mapped to the smallest level >= the sample (bisect_left); a sample above the
# top level would make bisect_left return len(Vlevels) and raise IndexError.
Vlevels = [i/3 for i in range(4)]
Vsteps = [Vlevels[bisect_left(Vlevels, x)] for x in I]

# Supply voltages: adaptive (follows I), constant (V) and stepped (Vsteps).
plt.figure()
plt.plot(t, I, linestyle='-', linewidth=4)
plt.plot(t, V, linestyle='--', linewidth=4)
plt.plot(t, Vsteps, linestyle='-.', linewidth=4)
plt.xticks([0, 1], ['0', 'PW'], size=18)
plt.yticks([0, 1], size=18)
plt.xlim([-0.2, 1.2])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.ylim([0, 1.2])
plt.ylabel('Voltage [V]', size=20)
if savefigs:
    plt.savefig('figures/Vcomp.eps')

# Instantaneous power drawn by each supply strategy (current * supply voltage).
Padaptive = np.multiply(I, I)
Pconst = np.multiply(I, V)
Pstep = np.multiply(I, Vsteps)

plt.figure()
plt.plot(t, Padaptive, linestyle='-', linewidth=4)
plt.plot(t, Pconst, linestyle='--', linewidth=4)
plt.plot(t, Pstep, linestyle='-.', linewidth=4)
plt.xticks([0, 1], ['0', 'PW'], size=18)
plt.yticks([0, 1], size=18)
plt.xlim([-0.2, 1.2])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.ylim([0, 1.2])
plt.ylabel('Power [W]', size=20)
if savefigs:
    plt.savefig('figures/Pcomp.eps')

# One standalone schematic per waveform shape (unit amplitude, width PW).
for w in wlist:
    plt.figure()
    wave, t = waveform(1, 0, 1, w, dt=0.001, t0=-0.2, tstop=1.2)
    if w == 'g':
        # Two opposing arrows form a double-headed width marker labelled
        # 6 sigma on the 'g' (presumably Gaussian) pulse.
        plt.arrow(0.2, 0.5, 0.8, 0, head_width=0.05, head_length=0.03, linewidth=5, length_includes_head=True, zorder=2)
        plt.arrow(0.8, 0.5, -0.8, 0, head_width=0.05, head_length=0.03, linewidth=5, length_includes_head=True, zorder=3)
        # Raw string: '\s' is an invalid escape (SyntaxWarning on Python 3.12+).
        plt.text(0.5, 0.25, r'6$\sigma$', size=48, ha='center')
    plt.plot(t, wave, linewidth=8, zorder=1)
    plt.xticks([0, 1], ['0', 'PW'], fontsize=48)
    plt.yticks([])
    plt.xlim([-0.2, 1.2])
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)
    plt.ylim([0, 1.2])
    plt.tight_layout()
    if savefigs:
        plt.savefig(f'figures/{w}_pulse.eps')
    
#%%
# Supply-voltage scaling data. Used below as a pandas DataFrame with an 'alpha'
# column, a 'pw' column, one column per waveform key and one '<key>_s' column
# per waveform (the latter plotted as eta in %).
# NOTE(review): pickle.load can execute arbitrary code; load only trusted files.
with open('datafiles/Vcdl.pkl', 'rb') as f:
    Vmax = pickle.load(f)

alist = Vmax['alpha'].unique()


def _tau_title(a):
    """Return the plot title showing the time constant tau = 1000/a ms."""
    if 1000/a >= 1:
        return fr'$\tau$ = {int(1000/a)} ms'
    return fr'$\tau$ = {1/int(a)*1e3} ms'


# Unscaled constant-supply energy (tau = infinity).
plt.figure()  # fresh figure: previously these curves landed on the last pulse schematic
for w in wlist:
    plt.plot(dur, econst[w], label=w, marker=next(marker), fillstyle='none', markevery=10)
plt.legend(legendlist)
plt.title(r'$\tau$ = $\infty$ ms')
plt.xlabel('Pulse Width [ms]')
plt.ylabel(r'$E_{constant}$ [pJ/k$\Omega$]')
plt.xlim([0, 1])
plt.ylim([50, 150])
plt.tight_layout()
if savefigs:
    plt.savefig('figures/Econst_alpha_0.eps')

# Constant-supply energy scaled by the per-alpha Vmax factor. Also print each
# waveform's minimum relative to the rectangular pulse, in percent.
for a in alist:
    plt.figure()
    print('alpha', a)
    vmax_a = Vmax[Vmax['alpha'] == a]
    econst_r_min = min(econst['r']*vmax_a['r'].to_numpy())
    for w in wlist:
        econst_scaled = econst[w]*vmax_a[w].to_numpy()
        plt.plot(dur, econst_scaled, label=w, marker=next(marker), fillstyle='none', markevery=10)
        print(w, (min(econst_scaled)/econst_r_min - 1)*100)
    plt.legend(legendlist)
    plt.xlim([0, 1])
    plt.ylim([50, 150])
    plt.title(_tau_title(a))
    plt.xlabel('Pulse Width [ms]')
    plt.ylabel(r'$E_{constant}$ [pJ/k$\Omega$]')
    plt.tight_layout()
    if savefigs:
        plt.savefig(f'figures/Econst_alpha_{int(a)}.eps')

wslist = [f'{i}_s' for i in wlist]

# eta against pulse width (pw <= 1) for every non-rectangular waveform.
for a in alist:
    fig, ax = plt.subplots()
    selection = Vmax[(Vmax['alpha'] == a) & (Vmax['pw'] <= 1)]
    for ws in wslist:
        selection.plot('pw', ws, ax=ax, logx=False, label=ws, marker=next(marker), fillstyle='none', markevery=10)
    # Drop the first (rectangular) curve. Axes.lines is an immutable ArtistList
    # since Matplotlib 3.7, so remove the artist itself instead of list.pop().
    ax.lines[0].remove()
    ax.set_xlabel('Pulse Width [ms]')
    ax.set_ylabel(r"$\eta$ [%]")
    ax.set_title(_tau_title(a))
    ax.legend(legendlist[1:])
    fig.tight_layout()
    if savefigs:
        fig.savefig(f'figures/alpha_{int(a)}.eps')

#%%
# Each waveform drawn at its own 0.5 ms threshold amplitude, normalized by the
# rectangular pulse's threshold, so shapes can be compared against I_0.
plt.figure()  # fresh figure: previously this drew onto the last eta axes
ith_rect = np.abs(data['r'][0.5]['th'])
for w in data:
    wave, t = waveform(0.5, 0, np.abs(data[w][0.5]['th']), w, dt=0.001, t0=-0.2, tstop=0.7)
    t = t*1e3  # rescaled to match the microsecond axis label below
    wave = wave/ith_rect
    plt.plot(t, wave)

plt.yticks([0, 0.5, 1, 1.5], size=18)
plt.xticks([0, 500], size=18)
plt.ylabel('$|I_{th}/I_0|$ ', size=20)
# Raw string: '\m' is an invalid escape (SyntaxWarning on Python 3.12+).
plt.xlabel(r'Time [$\mu$s]', size=20)

#%%
# Overview of all waveform shapes at unit amplitude on one set of axes.
plt.figure()  # fresh figure: previously this drew over the normalized-threshold plot
for w in data:
    wave, t = waveform(0.5, 0, 1, w, dt=0.001, t0=-0.2, tstop=0.8)
    t = t*1e3  # rescaled to match the microsecond axis label below
    plt.plot(t, wave, linewidth=3)
plt.grid(None)  # toggles the style's grid off
plt.legend(legendlist, fontsize=14)
plt.yticks([0, 1], size=18)
plt.xlim([-100, 900])
plt.xticks([0, 500], ['0', 'PW'], size=18)
# Raw strings: '\m' is an invalid escape (SyntaxWarning on Python 3.12+).
plt.xlabel(r'Time [$\mu$s]', size=20)
plt.ylabel(r'Current [$\mu$A]', size=20)
plt.tight_layout()
# Guarded like every other save; previously this always wrote and failed
# whenever figures/ did not exist.
if savefigs:
    plt.savefig('figures/shapes.eps')


#%%
# Biphasic current pulse (amplitude 3) and its stepped-voltage counterpart.
plt.figure()  # fresh figure: previously this drew over the shapes overview
I, t = waveform(1, 0, 3, 's', dt=0.001, t0=-0.2, tstop=2.2, biphasic=True)
plt.plot(t, I, linewidth=4)
plt.xticks([])
plt.yticks([])
plt.xlim([-0.2, 2.2])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.ylim([-3.2, 3.2])
plt.ylabel('Current', size=20)
plt.xlabel('Time', size=20)
if savefigs:
    # Distinct name: 'figures/Cpulse.eps' is already written by the monophasic
    # schematic and used to be silently overwritten here.
    plt.savefig('figures/Cpulse_biphasic.eps')

plt.figure()
# Each sample is rounded up to the next level; negative samples all map to 0.
Vlevels = [0, 1, 2, 3]
V = [Vlevels[bisect_left(Vlevels, x)] for x in I]
V = np.array(V)
# Rebuild the negative phase by mirroring the positive one. With dt=0.001 and
# t0=-0.2, sample 200 is t=0 and 1000 samples span the unit pulse width; the
# second phase is assumed to start at sample 1300 (t=1.1) — TODO confirm
# against waveform()'s default inter-phase delay.
V[1300:2300] = -1*V[200:1200]
plt.plot(t, V, linewidth=4, color='orange')
plt.xticks([])
plt.yticks([])
plt.xlim([-0.2, 2.2])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.ylim([-3.2, 3.2])
plt.ylabel('Voltage', size=20)
plt.xlabel('Time', size=20)

#%%
from scipy.integrate import cumulative_trapezoid

# Load voltage V = I + alpha*Q for a biphasic rectangular pulse, shown against
# the supply rail, followed by the corresponding current figure.
w = 'r'
pw = 1
amp = 1
alpha = 150
I, t = waveform(pw, 0, amp, w, dt=0.001, t0=-0.1, tstop=4.5, ipd=0.5, biphasic=True)
# Running charge. t is divided by 1000, presumably ms -> s. One O(n) cumulative
# integral replaces re-integrating every prefix with Simpson's rule (O(n^2)).
# initial=0 keeps Q the same length as I, with Q[0] = 0.
Q = cumulative_trapezoid(I, t/1000, initial=0)
V = I + alpha*Q

fig, ax = plt.subplots()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.plot(t, np.abs(V), label=r'$V_{load}$', linewidth=4)
# Shade the headroom between the 1.5 supply rail and |V| while current flows.
ax.fill_between(t, 1.5, np.abs(V), alpha=1, where=(np.abs(I) > 0), color='lightgrey')
ax.plot(t, -V, label=r'$V_{tis}$', lw=4, ls='--')
ax.axhline(1.5, label=r'$V_{supply}$', color='green', linestyle='-.', linewidth=4)
ax.legend(fontsize=20)
ax.set_xlabel('Time', fontsize=20)
ax.set_ylabel('Voltage', fontsize=20)
ax.set_xticks([])
ax.set_yticks([0])
ax.set_xticklabels([])
ax.set_yticklabels([0])
ax.set_xlim([-0.1, 4.5])
fig.tight_layout()
if savefigs:
    fig.savefig('figures/Vsignal.eps')

fig, ax = plt.subplots()
ax.plot(t, np.abs(I), label=r'$I_{stim}$', linewidth=4)
ax.plot(t, -I, label=r'$I_{tis}$', lw=4, ls='--')
ax.legend(fontsize=20)
ax.set_xlabel('Time', fontsize=20)
ax.set_ylabel('Current', fontsize=20)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_xticks([])
ax.set_yticks([0])
ax.set_xticklabels([])
ax.set_yticklabels([0])
ax.set_xlim([-0.1, 4.5])
fig.tight_layout()
if savefigs:
    fig.savefig('figures/Isignal.eps')