# -*- coding: utf-8 -*-
"""
Created on Tue Oct 26 13:22:07 2021

@author: francescvarkev
"""
import pandas as pd
# import os
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np


def meanSD(data):
    """Plot mean and median normalized strength-duration (S/D) curves.

    Thresholds in column ``'th'`` are grouped by ``('shape', 'dur')``; the
    absolute group statistics are normalized to the minimum of the
    rectangular (``'r'``) curve for plotting.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs columns ``'shape'``, ``'dur'`` and ``'th'`` and at least one
        row with ``shape == 'r'``.

    Returns
    -------
    fig1, ax1
        Figure/Axes with the mean curve of every shape on one plot.
    fig2, ax2
        Figure and 2-D array of Axes with one median + 25-75 % panel per
        shape.
    Smean : pandas.Series
        ``|mean(th)|`` indexed by ``(shape, dur)`` (not normalized).
    S50 : pandas.DataFrame
        ``|median(th)|`` indexed by ``dur``, one column per shape
        (not normalized).
    """
    shapes = data['shape'].unique()
    th = data.groupby(['shape', 'dur'])['th']

    # --- Mean S/D curves on a single axes ---------------------------------
    Smean = th.mean().abs()
    # Series.min()/max() skip NaN; builtin min()/max() give a result that
    # depends on where a NaN sits in the sequence.
    Snorm = Smean['r'].min()
    fig1, ax1 = plt.subplots()
    for w in shapes:
        (Smean / Snorm)[w].plot(ax=ax1, label=w)
    ax1.legend()
    ax1.set_title('Mean Normalized S/D, layer 5')
    ax1.set_xlabel('PW [ms]')
    ax1.set_ylabel('$I_{th}$/$I_{norm}$')
    ax1.set_ylim([0.9, (Smean['r'] / Snorm).max()])
    fig1.tight_layout()

    # --- Median curves with interquartile band, one panel per shape --------
    S50 = th.quantile(.5).unstack(level=0).abs()
    S25 = th.quantile(.25).unstack(level=0).abs()
    S75 = th.quantile(.75).unstack(level=0).abs()
    Snorm = S50['r'].min()

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # Keep the original 2x3 grid, but add columns when there are more than
    # six shapes instead of running out of axes.
    ncols = max(3, int(np.ceil(len(shapes) / 2)))
    fig2, ax2 = plt.subplots(2, ncols, sharex=True, sharey=True)
    for i, w in enumerate(shapes):
        ax = fig2.axes[i]
        # Panel i starts at colour i; rotating (instead of slicing) avoids an
        # empty colour cycle once i exceeds the number of colours.
        k = i % len(colors)
        ax.set_prop_cycle(color=colors[k:] + colors[:k])
        ax.plot(S50[w] / Snorm, label='median')
        ax.fill_between(S25.index, S25[w] / Snorm, S75[w] / Snorm,
                        alpha=0.2, label='25-75%')
        ax.set_title(w)
    ax2[0, 0].legend()

    ax2[0, 0].set_ylim([0, (S50['r'] / Snorm).max()])
    start, end = ax2[0, 0].get_ylim()
    ax2[0, 0].yaxis.set_ticks(np.arange(start, end, 2))
    fig2.supxlabel('PW [ms]')
    fig2.supylabel('$I_{th}$/$I_{norm}$')
    fig2.suptitle('Median Normalized S/D, layer 5')
    fig2.tight_layout()

    return fig1, ax1, fig2, ax2, Smean, S50

def meanQD(data):
    """Plot mean normalized charge-duration (Q/D) curves per pulse shape.

    The charge column ``'q'`` is grouped by ``('shape', 'dur')`` and the
    absolute group means are normalized to the minimum of the rectangular
    (``'r'``) curve.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs columns ``'shape'``, ``'dur'`` and ``'q'`` and at least one row
        with ``shape == 'r'``.

    Returns
    -------
    fig1, ax1
        Figure/Axes with one curve per shape.
    Qmean : pandas.Series
        ``|mean(q)|`` indexed by ``(shape, dur)`` (not normalized).
    """
    Qmean = data.groupby(['shape', 'dur'])['q'].mean().abs()
    # Series.min()/max() skip NaN; builtin min()/max() are order-dependent
    # when a NaN is present.
    Qnorm = Qmean['r'].min()
    fig1, ax1 = plt.subplots()
    for w in data['shape'].unique():
        (Qmean / Qnorm)[w].plot(ax=ax1, label=w)
    ax1.legend()
    ax1.set_title('Mean Normalized Q/D, layer 5')
    ax1.set_xlabel('PW [ms]')
    ax1.set_ylabel('$Q_{th}$/$Q_0$')
    ax1.set_ylim([0.9, (Qmean['r'] / Qnorm).max()])
    fig1.tight_layout()

    # Previously nothing was returned; returning the handles and the raw
    # means matches meanSD / meanED_adapt.
    return fig1, ax1, Qmean
    
def meanED_adapt(data):
    """Plot mean and median normalized energy-duration (E/D) curves.

    Energies in column ``'e'`` are grouped by ``('shape', 'dur')`` and
    normalized to the minimum of the rectangular (``'r'``) curve. Unlike
    meanSD, no absolute value is taken.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs columns ``'shape'``, ``'dur'`` and ``'e'`` and at least one
        row with ``shape == 'r'``.

    Returns
    -------
    fig1, ax1
        Figure/Axes with the mean curve of every shape.
    fig2, ax2
        Figure and 2-D array of Axes (2 rows) with one median + 25-75 %
        panel per shape.
    Emean : pandas.Series
        ``mean(e)`` indexed by ``(shape, dur)`` (not normalized).
    E50 : pandas.DataFrame
        ``median(e)`` indexed by ``dur``, one column per shape.
    """
    shapes = data['shape'].unique()
    e = data.groupby(['shape', 'dur'])['e']

    # --- Mean E/D curves ----------------------------------------------------
    Emean = e.mean()
    # Series.min()/max() skip NaN (builtin min/max are order-dependent).
    Enorm = Emean['r'].min()
    fig1, ax1 = plt.subplots()
    for w in shapes:
        (Emean / Enorm)[w].plot(ax=ax1, label=w)

    ax1.legend()
    ax1.set_title('Mean Normalized E/D, layer 5')
    ax1.set_xlabel('PW [ms]')
    ax1.set_ylabel('$E$/$E_{norm}$')
    ax1.set_ylim([0.8, (Emean['r'] / Enorm).max()])
    fig1.tight_layout()

    # --- Median curves with interquartile band -----------------------------
    E50 = e.quantile(.5).unstack(level=0)
    E25 = e.quantile(.25).unstack(level=0)
    E75 = e.quantile(.75).unstack(level=0)
    Enorm = E50['r'].min()

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    ncols = int(np.ceil(len(shapes) / 2))
    # squeeze=False keeps ax2 2-D, so ax2[0, 0] also works with <= 2 shapes
    # (a 2x1 grid would otherwise return a 1-D array).
    fig2, ax2 = plt.subplots(2, ncols, sharex=True, squeeze=False)
    for i, w in enumerate(shapes):
        ax = fig2.axes[i]
        k = i % len(colors)  # rotate instead of slicing past the end
        ax.set_prop_cycle(color=colors[k:] + colors[:k])
        ax.plot(E50[w] / Enorm, label='median')
        ax.fill_between(E25.index, E25[w] / Enorm, E75[w] / Enorm,
                        alpha=0.2, label='25-75%')
        ax.set_title(w)

    ax2[0, 0].legend()
    fig2.supxlabel('PW [ms]')
    fig2.supylabel('E/$E_{norm}$')
    fig2.suptitle('Median Normalized E/D, layer 5')
    fig2.tight_layout()

    return fig1, ax1, fig2, ax2, Emean, E50
   
    
def eSave_const(data):
    """Histogram of relative activation energy vs. shape 'r', constant supply.

    Constant-supply energy is computed as ``q * th`` and written to
    ``data['E_const']`` (the input DataFrame is modified in place). For every
    shape other than 'r' and every (cell_id, loc), the minimum E_const over
    pulse durations is divided by the minimum for 'r' at the same
    (cell_id, loc); ``ratio - 1`` is shown as a histogram with a KDE overlay.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape', 'q' and 'th', including rows with
        shape 'r'.

    Returns
    -------
    fig, ax, E_save
        ``E_save`` maps shape -> Series of ``ratio - 1`` (NaN dropped).
    """
    # th is the peak current needed and q the integral of I, so q * th
    # approximates the energy for a constant supply voltage. This now uses
    # the `data` argument rather than a module-level `df`.
    data['E_const'] = data['q'] * data['th']
    shapes = list(data['shape'].unique())
    shapes.remove('r')
    fig, ax = plt.subplots()

    E_ref = data.query('shape == "r"').groupby(['cell_id', 'loc'])['E_const'].min()
    E_save = {}
    for w in shapes:
        E_w = data.query(f'shape == "{w}"').groupby(['cell_id', 'loc', 'shape'])['E_const'].min()
        E_save[w] = (E_w / E_ref - 1).dropna()
        mean = round(E_save[w].mean() * 100, 1)
        std = round(E_save[w].std() * 100, 1)
        # Raw string: '\p' is not a valid Python escape sequence.
        E_save[w].hist(density=True, label=rf'{w}: {mean}$\pm${std}%', alpha=0.5, ax=ax)

    # Reset the property cycle before the KDE lines are drawn on the same axes.
    ax.set_prop_cycle(None)
    for w in shapes:
        E_save[w].plot.kde(ax=ax, label='_nolegend_', bw_method=1)
    ax.legend()
    ax.set_xlabel("$E$/$E_r$ - 1")
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    fig.suptitle("Relative Activation Energy, constant supply")

    if E_save:  # min()/max() of an empty generator would raise
        low = min(E_save[w].min() for w in E_save)
        high = max(E_save[w].max() for w in E_save)
        ax.set_xlim([low, high])

    return fig, ax, E_save

def eSave_const_box(data):
    """Box plot of relative constant-supply activation energy vs. shape 'r'.

    Side effect: stores ``q * th`` in ``data['E_const']``. For each shape
    except 'r', the per-(cell_id, loc) minimum of E_const is divided by the
    matching minimum for 'r'; the box plot shows ``ratio - 1`` as a
    percentage.

    Returns
    -------
    fig, ax, E_save
        ``E_save`` is a DataFrame with one column per non-'r' shape; later
        columns are aligned to the index of the first one assigned.
    """
    # th = peak current, q = integral of I -> energy for constant voltage.
    data['E_const'] = data['q']*data['th']
    candidates = list(data['shape'].unique())
    candidates.remove('r')  # ValueError if there is no rectangular data
    fig, ax = plt.subplots()

    baseline = data.query('shape == "r"').groupby(['cell_id', 'loc'])['E_const'].min()
    E_save = pd.DataFrame()
    for name in candidates:
        best = data.query(f'shape == "{name}"').groupby(['cell_id', 'loc', 'shape'])['E_const'].min()
        relative = (best / baseline - 1).dropna()
        # Move 'shape' out of the index so the column is keyed by (cell_id, loc).
        E_save[name] = relative.reset_index(['shape'])['E_const']

    E_save.plot.box(ax=ax)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    ax.set_ylabel("$E$/$E_r$ - 1")
    fig.suptitle("Relative Activation Energy, constant supply")

    return fig, ax, E_save

def eSave_adaptive_box(data):
    """Box plot of relative adaptive-supply activation energy vs. shape 'r'.

    For each shape except 'r' and every (cell_id, loc), the minimum of column
    'e' over pulse durations is divided by the minimum for 'r' at the same
    (cell_id, loc); the box plot shows ``ratio - 1`` as a percentage.

    Returns
    -------
    fig, ax, E_save
        ``E_save`` is a DataFrame with one column per non-'r' shape; later
        columns are aligned to the index of the first one assigned.
    """
    candidates = list(data['shape'].unique())
    candidates.remove('r')  # ValueError if there is no rectangular data
    fig, ax = plt.subplots()

    baseline = data.query('shape == "r"').groupby(['cell_id', 'loc'])['e'].min()
    E_save = pd.DataFrame()
    for name in candidates:
        best = data.query(f'shape == "{name}"').groupby(['cell_id', 'loc', 'shape'])['e'].min()
        relative = (best / baseline - 1).dropna()
        # Move 'shape' out of the index so the column is keyed by (cell_id, loc).
        E_save[name] = relative.reset_index(['shape'])['e']

    E_save.plot.box(ax=ax)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    ax.set_ylabel("$E$/$E_r$ - 1")
    fig.suptitle("Relative Activation Energy, adaptive supply")

    return fig, ax, E_save


def eSaveCdl_const(data):
    """Histogram of relative energy incl. V_cdl vs. 'r', constant supply.

    For shapes 'g', 's' and 't' and every (cell_id, loc), the minimum of
    column ``'eCdl_const'`` over pulse durations is divided by the minimum
    for shape 'r' at the same (cell_id, loc); ``ratio - 1`` is shown as a
    histogram with a KDE overlay.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape' and 'eCdl_const'.

    Returns
    -------
    fig, ax, E_save
        ``E_save`` maps shape -> Series of ``ratio - 1`` (NaN dropped).
    """
    # Reads the `data` argument; the body used to read a module-level `df`,
    # which only exists when this file is run as a script.
    col = 'eCdl_const'
    shapes = ['g', 's', 't']
    fig, ax = plt.subplots()

    E_ref = data.query('shape == "r"').groupby(['cell_id', 'loc'])[col].min()
    E_save = {}
    for w in shapes:
        E_w = data.query(f'shape == "{w}"').groupby(['cell_id', 'loc', 'shape'])[col].min()
        E_save[w] = (E_w / E_ref - 1).dropna()
        mean = round(E_save[w].mean() * 100, 1)
        std = round(E_save[w].std() * 100, 1)
        # Raw string: '\p' is not a valid Python escape sequence.
        E_save[w].hist(density=True, label=rf'{w}: {mean}$\pm${std}%', alpha=0.5, ax=ax)

    # Reset the property cycle before the KDE lines are drawn on the same axes.
    ax.set_prop_cycle(None)
    for w in shapes:
        E_save[w].plot.kde(ax=ax, label='_nolegend_')
    ax.legend()
    ax.set_xlabel("$E$/$E_r$ - 1")
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    fig.suptitle("Activation Energy, incl. $V_{cdl}$, constant supply")

    low = min(E_save[w].min() for w in E_save)
    high = max(E_save[w].max() for w in E_save)
    ax.set_xlim([low, high])

    return fig, ax, E_save
    
def eSaveCdl_adapt(data):
    """Histogram of relative energy incl. V_cdl vs. 'r', adaptive supply.

    For shapes 'g', 's' and 't' and every (cell_id, loc), the minimum of
    column ``'eCdl_adapt'`` over pulse durations is divided by the minimum
    for shape 'r' at the same (cell_id, loc); ``ratio - 1`` is shown as a
    histogram with a KDE overlay.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape' and 'eCdl_adapt'.

    Returns
    -------
    fig, ax, E_save
        ``E_save`` maps shape -> Series of ``ratio - 1`` (NaN dropped).
    """
    # Reads the `data` argument; the body used to read a module-level `df`,
    # which only exists when this file is run as a script.
    col = 'eCdl_adapt'
    shapes = ['g', 's', 't']
    fig, ax = plt.subplots()

    E_ref = data.query('shape == "r"').groupby(['cell_id', 'loc'])[col].min()
    E_save = {}
    for w in shapes:
        E_w = data.query(f'shape == "{w}"').groupby(['cell_id', 'loc', 'shape'])[col].min()
        E_save[w] = (E_w / E_ref - 1).dropna()
        mean = round(E_save[w].mean() * 100, 1)
        std = round(E_save[w].std() * 100, 1)
        # Raw string: '\p' is not a valid Python escape sequence.
        E_save[w].hist(density=True, label=rf'{w}: {mean}$\pm${std}%', alpha=0.5, ax=ax)

    # Reset the property cycle before the KDE lines are drawn on the same axes.
    ax.set_prop_cycle(None)
    for w in shapes:
        E_save[w].plot.kde(ax=ax, label='_nolegend_')
    ax.legend()
    ax.set_xlabel("$E$/$E_r$ - 1")
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    fig.suptitle("Activation Energy, incl. $V_{cdl}$,  adaptive supply")

    low = min(E_save[w].min() for w in E_save)
    high = max(E_save[w].max() for w in E_save)
    ax.set_xlim([low, high])

    return fig, ax, E_save

def pwShift_const(data):
    """Histogram of the optimal-pulse-width shift vs. 'r', constant supply.

    For each (cell_id, loc), the duration that minimizes ``'E_const'`` is
    found per shape. For shapes 'g', 's' and 't', the ratio
    ``PW_opt(shape) / PW_opt('r')`` is shown as a histogram with a KDE
    overlay.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape', 'dur' and 'E_const'. The __main__
        block creates 'E_const' as q * th.

    Returns
    -------
    fig, ax, PW_shift
        ``PW_shift`` maps shape -> Series of ratios indexed by
        (cell_id, loc). Locations without a valid optimum for either shape
        are dropped.
    """
    def opt_pw(shape):
        # Duration at minimum E_const per (cell_id, loc). Rows with NaN
        # energy are dropped first so all-NaN groups vanish instead of
        # yielding a NaN label.
        sub = data[data['shape'] == shape].dropna(subset=['E_const'])
        best_rows = sub.groupby(['cell_id', 'loc'])['E_const'].idxmin()
        picked = sub.loc[best_rows.to_numpy(), ['cell_id', 'loc', 'dur']]
        return picked.set_index(['cell_id', 'loc'])['dur']

    fig, ax = plt.subplots()
    pw_ref = opt_pw('r')
    PW_shift = {}
    for w in ['g', 's', 't']:
        # Divide with alignment on (cell_id, loc). The old code divided two
        # reset_index() Series by position, which paired different locations
        # whenever one shape was missing a location. It also read the
        # module-level `df` instead of `data`.
        PW_shift[w] = (opt_pw(w) / pw_ref).dropna()
        mean = round(PW_shift[w].mean(), 2)
        std = round(PW_shift[w].std(), 2)
        PW_shift[w].hist(density=True, label=rf'{w}: {mean}$\pm${std}', alpha=0.5, ax=ax)

    # Reset the property cycle before the KDE lines are drawn on the same axes.
    ax.set_prop_cycle(None)
    for w in ['g', 's', 't']:
        PW_shift[w].plot.kde(ax=ax, label='_nolegend_', bw_method=0.5)
    ax.legend()
    ax.set_xlabel("$PW_{opt}$/$PW_{opt,r}$")
    low = min(PW_shift[w].min() for w in PW_shift)
    high = max(PW_shift[w].max() for w in PW_shift)
    ax.set_xlim([low, high])
    fig.suptitle("Relative PW Shift with Contant Supply")

    return fig, ax, PW_shift

def eSave_adaptive(df):
    """Histogram of relative activation energy vs. shape 'r', adaptive supply.

    For every shape other than 'r' and every (cell_id, loc), the minimum of
    column 'e' over pulse durations is divided by the minimum for 'r' at the
    same (cell_id, loc); ``ratio - 1`` is shown as a histogram with a KDE
    overlay.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape' and 'e', including rows with shape 'r'.

    Returns
    -------
    fig, ax, E_save
        ``E_save`` maps shape -> Series of ``ratio - 1`` (NaN dropped).
    """
    fig, ax = plt.subplots()
    shapes = list(df['shape'].unique())
    shapes.remove('r')

    E_ref = df.query('shape == "r"').groupby(['cell_id', 'loc'])['e'].min()
    E_save = {}
    for w in shapes:
        E_w = df.query(f'shape == "{w}"').groupby(['cell_id', 'loc', 'shape'])['e'].min()
        E_save[w] = (E_w / E_ref - 1).dropna()
        mean = round(E_save[w].mean() * 100, 1)
        std = round(E_save[w].std() * 100, 1)
        # Raw string: '\p' is not a valid Python escape sequence.
        E_save[w].hist(density=True, label=rf'{w}: {mean}$\pm${std}%', alpha=0.5, ax=ax)

    # Reset the property cycle before the KDE lines are drawn on the same axes.
    ax.set_prop_cycle(None)
    for w in shapes:
        E_save[w].plot.kde(ax=ax, label='_nolegend_', bw_method=1)
    ax.legend()
    if E_save:  # min()/max() of an empty generator would raise
        low = min(E_save[w].min() for w in E_save)
        high = max(E_save[w].max() for w in E_save)
        ax.set_xlim([low, high])
    ax.set_xlabel("$E$/$E_r$ - 1")
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    fig.suptitle("Relative Activation Energy, adaptive supply")

    return fig, ax, E_save
         



def pwShift_adaptive(df):
    """Histogram of the optimal-pulse-width shift vs. 'r', adaptive supply.

    For each (cell_id, loc), the duration that minimizes 'e' is found per
    shape. For shapes 'g', 's' and 't', the ratio
    ``PW_opt(shape) / PW_opt('r')`` is shown as a histogram with a KDE
    overlay.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape', 'dur' and 'e'.

    Returns
    -------
    fig, ax, PW_shift
        ``PW_shift`` maps shape -> Series of ratios indexed by
        (cell_id, loc). Locations without a valid optimum for either shape
        are dropped.
    """
    def opt_pw(shape):
        # Duration at minimum 'e' per (cell_id, loc). NaN rows are dropped
        # first so all-NaN groups vanish instead of yielding a NaN label.
        sub = df[df['shape'] == shape].dropna(subset=['e'])
        best_rows = sub.groupby(['cell_id', 'loc'])['e'].idxmin()
        picked = sub.loc[best_rows.to_numpy(), ['cell_id', 'loc', 'dur']]
        return picked.set_index(['cell_id', 'loc'])['dur']

    fig, ax = plt.subplots()
    pw_ref = opt_pw('r')
    PW_shift = {}

    for w in ['g', 's', 't']:
        # Divide with alignment on (cell_id, loc). The old code divided two
        # reset_index() Series by position, which paired different locations
        # whenever one shape was missing a location.
        PW_shift[w] = (opt_pw(w) / pw_ref).dropna()
        mean = round(PW_shift[w].mean(), 2)
        std = round(PW_shift[w].std(), 2)
        PW_shift[w].hist(density=True, label=rf'{w}: {mean}$\pm${std}', alpha=0.5, ax=ax)

    # Reset the property cycle before the KDE lines are drawn on the same axes.
    ax.set_prop_cycle(None)
    for w in ['g', 's', 't']:
        PW_shift[w].plot.kde(ax=ax, label='_nolegend_', bw_method=0.5)
    ax.legend()
    ax.set_xlabel("$PW_{opt}$/$PW_{opt,r}$")
    fig.suptitle("Relative PW Shift with Adaptive Supply")
    low = min(PW_shift[w].min() for w in PW_shift)
    high = max(PW_shift[w].max() for w in PW_shift)
    ax.set_xlim([low, high])

    return fig, ax, PW_shift
    
def eSave_old(opt):
    """Older variant: histograms of precomputed savings ``save_e_{g,s,t}``.

    Draws one figure with a panel per ``cell_id``. It then draws a second
    figure with all cells pooled (20 bins). Every histogram has a KDE
    overlay and a percent-formatted x axis.

    Parameters
    ----------
    opt : pandas.DataFrame
        Needs 'cell_id' and columns ``save_e_g``, ``save_e_s`` and
        ``save_e_t``.

    Returns
    -------
    dict
        Shape -> the corresponding ``save_e_*`` column.
    """
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    shapes = ['g', 's', 't']
    cell_ids = opt['cell_id'].unique()

    # At least 5 rows as before; more rows for >5 cells instead of an
    # IndexError.
    fig, ax = plt.subplots(max(5, len(cell_ids)), 1, sharex=True, figsize=(10, 10))
    for i, cell_id in enumerate(cell_ids):
        cell = opt[opt['cell_id'] == cell_id]
        ax[i].set_prop_cycle(color=colors[1:])
        for w in shapes:
            col = cell[f'save_e_{w}']
            mean = round(col.mean() * 100, 1)
            std = round(col.std() * 100, 1)
            # Raw string: '\p' is not a valid Python escape sequence.
            ax[i].hist(col, density=True, label=rf"{w}: {mean}, $\pm${std}%", alpha=0.5)

        # Restart the colours so each KDE line starts at the same colour as
        # the histograms.
        ax[i].set_prop_cycle(color=colors[1:])
        for w in shapes:
            cell[f'save_e_{w}'].plot.kde(ax=ax[i], bw_method=1, label='_nolegend_')

        ax[i].set_title(cell_id)
        ax[i].legend()
        ax[i].xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    fig.tight_layout()

    # Pooled over all cells.
    fig, ax = plt.subplots()
    ax.set_prop_cycle(color=colors[1:])
    for w in shapes:
        col = opt[f'save_e_{w}']
        mean = round(col.mean() * 100, 1)
        std = round(col.std() * 100, 1)
        ax.hist(col, bins=20, density=True, label=rf"{w}: {mean}, $\pm${std}%", alpha=0.5)
    ax.set_prop_cycle(color=colors[1:])
    for w in shapes:
        opt[f'save_e_{w}'].plot.kde(ax=ax, bw_method=1, label='_nolegend_')

    ax.legend()
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
    fig.tight_layout()

    E_save = {w: opt[f'save_e_{w}'] for w in shapes}
    low = min(E_save[w].min() for w in E_save)
    high = max(E_save[w].max() for w in E_save)
    ax.set_xlim([low, high])
    return E_save

def pwShift_old(opt):
    """Older variant: histograms of precomputed PW shifts ``shift_pw_{g,s,t}``.

    Draws one figure with a panel per ``cell_id``. It then draws a second
    figure with all cells pooled (20 bins). Every histogram has a KDE
    overlay.

    Parameters
    ----------
    opt : pandas.DataFrame
        Needs 'cell_id' and columns ``shift_pw_g``, ``shift_pw_s`` and
        ``shift_pw_t``.

    Returns
    -------
    dict
        Shape -> the corresponding ``shift_pw_*`` column.
    """
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    shapes = ['g', 's', 't']
    cell_ids = opt['cell_id'].unique()

    # At least 5 rows as before; more rows for >5 cells instead of an
    # IndexError.
    fig, ax = plt.subplots(max(5, len(cell_ids)), 1, sharex=True, figsize=(10, 10))
    for i, cell_id in enumerate(cell_ids):
        cell = opt[opt['cell_id'] == cell_id]
        ax[i].set_prop_cycle(color=colors[1:])
        for w in shapes:
            col = cell[f'shift_pw_{w}']
            mean = round(col.mean(), 2)
            std = round(col.std(), 2)
            # Raw string: '\p' is not a valid Python escape sequence.
            ax[i].hist(col, density=True, label=rf"{w}: {mean}, $\pm${std}", alpha=0.5)

        # Restart the colours so each KDE line starts at the same colour as
        # the histograms.
        ax[i].set_prop_cycle(color=colors[1:])
        for w in shapes:
            cell[f'shift_pw_{w}'].plot.kde(ax=ax[i], bw_method=1, label='_nolegend_')
        ax[i].set_title(cell_id)
        ax[i].legend()
    fig.tight_layout()

    # Pooled over all cells.
    fig, ax = plt.subplots()
    ax.set_prop_cycle(color=colors[1:])
    for w in shapes:
        col = opt[f'shift_pw_{w}']
        mean = round(col.mean(), 2)
        std = round(col.std(), 2)
        ax.hist(col, bins=20, density=True, label=rf"{w}: {mean}, $\pm${std}", alpha=0.5)
    ax.set_prop_cycle(color=colors[1:])
    for w in shapes:
        opt[f'shift_pw_{w}'].plot.kde(ax=ax, bw_method=1, label='_nolegend_')
    ax.legend()
    fig.tight_layout()

    return {w: opt[f'shift_pw_{w}'] for w in shapes}

def plotED(df, cell_id=20, loc=None, shape=None):
    """Plot normalized energy-duration curves for one cell and location.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape', 'dur' and 'e'.
    cell_id : int, default 20
        Cell to plot.
    loc : str, optional
        Value of the 'loc' column. Defaults to the first location listed
        for ``cell_id``.
    shape : iterable of str, optional
        Shapes to draw. Defaults to every shape in ``df``. A string such as
        ``'rg'`` is iterated per character. The argument used to be
        accepted but ignored.

    Energies are divided by the minimum 'e' of shape 'r' at this location.

    Returns
    -------
    data, fig, ax
        ``data`` is the selected cell/location subset sorted by 'dur'.
    """
    if loc is None:
        loc = df.query(f'cell_id == {cell_id}')['loc'].iloc[0]
    shapes = df['shape'].unique() if shape is None else list(shape)

    fig, ax = plt.subplots()

    data = df.query(f'cell_id == {cell_id} and loc =="{loc}"').sort_values('dur')
    # Series.min() skips NaN; builtin min() is order-dependent with NaN.
    Enorm = data.loc[data['shape'] == 'r', 'e'].min()
    for w in shapes:
        sel = data[data['shape'] == w]
        ax.plot(sel['dur'], sel['e'] / Enorm, label=w)

    return data, fig, ax

def plotSD(df, cell_id=20, loc=None, shape=None):
    """Plot normalized strength-duration curves for one cell and location.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape', 'dur' and 'th'.
    cell_id : int, default 20
        Cell to plot.
    loc : str, optional
        Value of the 'loc' column. Defaults to the first location listed
        for ``cell_id``.
    shape : iterable of str, optional
        Shapes to draw. Defaults to every shape in ``df``. A string such as
        ``'rg'`` is iterated per character. The argument used to be
        accepted but ignored.

    |th| is divided by the minimum |th| of shape 'r' at this location.

    Returns
    -------
    data, fig, ax
        ``data`` is the selected cell/location subset sorted by 'dur'.
    """
    if loc is None:
        loc = df.query(f'cell_id == {cell_id}')['loc'].iloc[0]
    shapes = df['shape'].unique() if shape is None else list(shape)

    fig, ax = plt.subplots()

    data = df.query(f'cell_id == {cell_id} and loc =="{loc}"').sort_values('dur')
    # Series.min() skips NaN; builtin min() is order-dependent with NaN.
    Snorm = data.loc[data['shape'] == 'r', 'th'].abs().min()
    for w in shapes:
        sel = data[data['shape'] == w]
        ax.plot(sel['dur'], sel['th'].abs() / Snorm, label=w)

    return data, fig, ax


def meanXD(data, x, name=None, plot=True):
    """Mean and median normalized X-duration curves for any column ``x``.

    Like meanSD, but for an arbitrary column. ``x`` is grouped by
    ``('shape', 'dur')``, absolute values are taken, and for plotting the
    curves are normalized to the minimum of the rectangular (``'r'``) curve.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs 'shape', 'dur' and column ``x``. Plotting needs shape 'r'.
    x : str
        Column to aggregate.
    name : str, optional
        Symbol used in titles/labels (inside ``$...$``). Defaults to ``x``.
    plot : bool, default True
        If False, only the statistics are computed and the four figure/axes
        slots of the returned tuple are None. The argument used to be
        accepted but ignored.

    Returns
    -------
    fig1, ax1, fig2, ax2, Xmean, X50
        ``Xmean`` is ``|mean(x)|`` indexed by (shape, dur). ``X50`` is
        ``|median(x)|`` indexed by dur with one column per shape. Neither
        is normalized.
    """
    if name is None:
        name = x

    grouped = data.groupby(['shape', 'dur'])[x]
    Xmean = grouped.mean().abs()
    X50 = grouped.quantile(.5).unstack(level=0).abs()
    if not plot:
        return None, None, None, None, Xmean, X50

    shapes = data['shape'].unique()

    # --- Mean curves; Series.min/max skip NaN unlike builtin min/max --------
    Xnorm = Xmean['r'].min()
    fig1, ax1 = plt.subplots()
    for w in shapes:
        (Xmean / Xnorm)[w].plot(ax=ax1, label=w)

    ax1.legend()
    ax1.set_title(f'Mean Normalized {name}/D, layer 5')
    ax1.set_xlabel('PW [ms]')
    ax1.set_ylabel(f'${name}$/${name}_{{0}}$')
    ax1.set_ylim([0.8, (Xmean['r'] / Xnorm).max()])
    fig1.tight_layout()

    # --- Median curves with 25-75 % band, one panel per shape --------------
    X25 = grouped.quantile(.25).unstack(level=0).abs()
    X75 = grouped.quantile(.75).unstack(level=0).abs()
    Xnorm = X50['r'].min()

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # 2x2 as before; widen the grid when there are more than four shapes.
    ncols = max(2, int(np.ceil(len(shapes) / 2)))
    fig2, ax2 = plt.subplots(2, ncols, sharex=True)
    for i, w in enumerate(shapes):
        axi = fig2.axes[i]
        k = i % len(colors)  # rotate instead of slicing past the end
        axi.set_prop_cycle(color=colors[k:] + colors[:k])
        axi.plot(X50[w] / Xnorm, label='median')
        axi.fill_between(X25.index, X25[w] / Xnorm, X75[w] / Xnorm,
                         alpha=0.2, label='25-75%')
        axi.set_title(w)

    ax2[0, 0].legend()
    fig2.supxlabel('PW [ms]')
    fig2.supylabel(f'${name}$/${name}_{{0}}$')
    fig2.suptitle(f'Median Normalized {name}/D, layer 5')
    fig2.tight_layout()

    return fig1, ax1, fig2, ax2, Xmean, X50


def ED_locsplit(df, cell_ids):
    """Plot energy-duration curves on a spatial grid, one figure per z-plane.

    For every cell in ``cell_ids``, the 'loc' strings are parsed as integer
    triples ``"(x, y, z)"`` and grouped by z. Each z gets a figure with one
    column per x and one row per y. Each panel shows 'e' vs 'dur' for every
    shape.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape', 'dur' and 'e'.
    cell_ids : iterable
        Cells to plot.

    Returns
    -------
    tuple
        ``(loclist, z, subxlist, subylist, sublocarray, dat)`` from the last
        iteration, for interactive inspection. Values never assigned are
        None (previously an UnboundLocalError when nothing was plotted).
    """
    wlist = df['shape'].unique()
    loclist = z = subxlist = subylist = sublocarray = dat = None
    for cell_id in cell_ids:
        # "(x, y, z)" -> [x, y, z]
        loclist = [list(map(int, loc[1:-1].split(',')))
                   for loc in df.query(f'cell_id == {cell_id}')['loc'].unique()]
        if not loclist:
            continue
        locarray = np.array(loclist)

        for z in np.unique(locarray[:, 2]):
            sublocarray = locarray[locarray[:, 2] == z]
            subxlist = np.unique(sublocarray[:, 0])
            subylist = np.unique(sublocarray[:, 1])

            # squeeze=False keeps ax 2-D even when there is a single x or y.
            fig, ax = plt.subplots(len(subylist), len(subxlist), sharex=True,
                                   sharey=False, figsize=(70, 42), squeeze=False)

            for n, x in enumerate(subxlist):
                for m, y in enumerate(subylist):
                    for w in wlist:
                        # Rebuilds the loc string; assumes the stored format
                        # is exactly "(x, y, z)" with ", " separators.
                        dat = df.query(f'cell_id == {cell_id} & loc == "({x}, {y}, {z})" & shape == "{w}"').sort_values('dur')
                        ax[m, n].plot(dat['dur'], dat['e'], label=w)

            for n, x in enumerate(subxlist):
                ax[0, n].title.set_text('x = ' + str(x))
            for m, y in enumerate(subylist):
                ax[m, 0].set_ylabel('y = ' + str(y))

            ax[0, 0].legend()
            fig.suptitle(f'Energy-Duration, cell={cell_id}, z={z}', fontsize=20)
            fig.tight_layout(rect=[0, 0.03, 1, 0.98])

    return loclist, z, subxlist, subylist, sublocarray, dat

def optpwlists(df, cell_ids):
    """Collect the energy-optimal pulse duration per shape and location.

    For every shape in ``df`` and every location of every cell in
    ``cell_ids``, the 'dur' of the row with minimum 'e' is appended to that
    shape's list. Iteration order (cell, then location) is the same for all
    shapes, so the lists correspond element-wise.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'cell_id', 'loc', 'shape', 'dur' and 'e'. Any index is
        supported.
    cell_ids : iterable
        Cells to include.

    Returns
    -------
    dict
        Shape -> list of optimal durations. The entry is NaN where a
        location has no valid (non-NaN) energy for that shape.
    """
    optpw = {}
    for w in df['shape'].unique():
        optpw[w] = []
        for cell_id in cell_ids:
            cell_df = df[df['cell_id'] == cell_id]
            for loc in cell_df['loc'].unique():
                energies = cell_df.loc[(cell_df['loc'] == loc) & (cell_df['shape'] == w), 'e']
                if energies.notna().any():
                    # idxmin returns an index *label*, so look it up with
                    # .loc; the old .iloc picked the wrong row (or raised)
                    # unless the index was a 0..n-1 RangeIndex.
                    optpw[w].append(df.loc[energies.idxmin(), 'dur'])
                else:
                    # Keep the lists aligned across shapes instead of crashing.
                    optpw[w].append(np.nan)
    return optpw
#%% 
if __name__ == '__main__':
    # Load the combined results table.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    df = pd.read_pickle('DataFilesFinal/DataFrame_all.pkl')

    # Where ex == 0 the threshold, energy and charge are blanked to NaN.
    not_excited = df['ex'] == 0
    df.loc[not_excited, ['th', 'e', 'q']] = np.nan

    # Constant-supply energy: q (integral of I) times th (peak current).
    df['E_const'] = df['q']*df['th']

    # Other tables that can be loaded:
    # df_sum = pd.read_pickle('DataFrame_sum.pkl')
    # df_opt = pd.read_pickle('DataFrame_opt.pkl')

    # if not(os.path.exists('Figures/' + str(cell_id))):
    #        os.mkdir('Figures/' + str(cell_id))

    # Analyses -- uncomment the ones to run:
    # _ = meanSD(df)
    # _ = meanED_adapt(df)
    # _ = eSave_adaptive(df)
    # _ = eSave_const(df)
    # _ = pwShift_adaptive(df)
    # _ = pwShift_const(df)
    # plotSDFitted(df)
    # plotEDFitted(df)
    # E_save_fit = energySavingsFitHist(data, meta, opt)
    # optPwFitHist(data, meta, opt)
    
#%% 
# def energySavingsFitHist(data, meta, opt):
#     plt.figure()
#     loclist = meta['loclist']
#     durlist = meta['dlist']
#     wlist = meta['shapelist']
#     cell_id = int(meta['cell_id'])
#     min_e = {}
#     for w in wlist:
#         min_e[w] = {}
#         for loc in loclist:
#             if opt[loc]['pw_' + str(w)] != durlist[-1]:
#                 min_e[w][loc] = min(data[cell_id][loc][w][0][0]['fit']['e'])
    
#     E_save = {}
#     for w in wlist:
#         if w not in ['r']:
#             E_save[w] = []
#             for loc in loclist:
#                 if loc in min_e['r'] and loc in min_e[w]:
#                     E_save[w].append(1 - min_e[w][loc]/min_e['r'][loc])
                                   
#     colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
#     plt.gca().set_prop_cycle(color=colors[1:])
#     for w in E_save:
#         plt.hist(E_save[w], bins=20, density=True, label = str(w) + ' n=' + str(len(E_save[w])))
    
#     plt.legend()
#     plt.title('Energy savings wrt rectangular pulse, fitted data')
#     plt.gca().xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
#     plt.xlabel('1-$E_x$/$E_r$', size =14) 
    
#     return E_save


    
# def optPwFitHist(data, meta, opt):
#     plt.figure()
#     loclist = meta['loclist']
#     wlist = meta['shapelist']
#     cell_id = int(meta['cell_id'])
#     pw_opt = {}
#     for w in wlist:
#         pw_opt[w] = []

#         for loc in loclist:
#             i = np.where(data[cell_id][loc][w][0][0]['fit']['e'] == min(data[cell_id][loc][w][0][0]['fit']['e']))[0][0]
#             pw_opt[w].append(data[cell_id][loc][w][0][0]['fit']['d'][i])
                
#         plt.hist(pw_opt[w], label = w)
    
#     plt.legend()
#     plt.title('Optimal PW, fitted data', size=14)
#     plt.xlabel('PW [ms]',size=14)
#     # plt.savefig('Figures\\' + session + '\OptPWHist.eps')


# def plotSDFitted(data, meta):
#     #Using fitted data from SD curve, plot and compare ED data to measured data
#     wlist = meta['shapelist']
#     loclist = meta['loclist']
#     cell_id = int(meta['cell_id'])
#     Smean = {}
#     S25 = {}
#     S50 = {}
#     S75 = {}
#     for w in wlist:
#         Slist = []
#         for loc in loclist:
#             Slist.append(np.abs(data[cell_id][loc][w][0][0]['fit']['s']))
#         Sarrays = [np.array(x) for x in Slist]
#         Smean[w] = [np.mean(k) for k in zip(*Sarrays)]
#         # fig.axes[i].plot(d, Smean)
        
#         # plt.figure() # Create normalized SD curve for r as in fig. 5d of Aberra et al. 
#         S25[w] = [np.percentile(k, 25) for k in zip(*Sarrays)]
#         S50[w] = [np.percentile(k, 50) for k in zip(*Sarrays)]
#         S75[w] = [np.percentile(k, 75) for k in zip(*Sarrays)]
    
#     xdurlist = data[cell_id][loc][w][0][0]['fit']['d']
#     Snorm = min(Smean['r'])
#     plt.figure()
#     for w in wlist:
#         plt.plot(xdurlist, Smean[w]/Snorm, label = w)
#     plt.legend()
#     plt.title('Normalized Mean Strength-Duration Curves')
    
#     plt.savefig('Figures/' + meta['cell_id'] + '/Strength-Duration_Mean_Fit.png')
    
#     Snorm = min(S50['r'])
#     fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
#     i=0
#     for w in wlist:
#         fig.axes[i].plot(xdurlist, S25[w]/Snorm, label = '25%')
#         fig.axes[i].plot(xdurlist, S50[w]/Snorm, label='median')
#         fig.axes[i].plot(xdurlist, S75[w]/Snorm, label='75%')
#         fig.axes[i].set_title(w)
#         fig.axes[i].set_xlabel('PW [ms]')
#         fig.axes[i].set_ylabel('$I_{th}$/$I_{norm}$')
#         i += 1
#     fig.axes[0].set_ylim([0, max(S75['r'])/Snorm])
#     fig.axes[0].legend()
#     fig.suptitle('Normalized Strength-duration curves')
#     fig.tight_layout()
    
#     fig.savefig('Figures/' + meta['cell_id'] + '/Strength-Duration_Seperated_Fit.png')
    
#     plt.figure()
    
#     for w in wlist:
#         plt.plot(xdurlist, S50[w]/Snorm, label= w)
#     plt.legend()
#     plt.title('Normalized Median Strength-Duration Curves')
#     plt.savefig('Figures/' + meta['cell_id'] + '/Strength-Duration_Median_Fit.png')
#     return 


# def plotEDFitted(data, meta):
#     #Using fitted from SD curve, plot and compare ED data to measured data
#     wlist = meta['shapelist']
#     loclist = meta['loclist']
#     cell_id = int(meta['cell_id'])
#     Emean = {}
#     E25 = {}
#     E50 = {}
#     E75 = {}
#     for w in wlist:
#         Elist = []
#         for loc in loclist:
#             Elist.append(data[cell_id][loc][w][0][0]['fit']['e'])
#         Earrays = [np.array(x) for x in Elist]
#         Emean[w] = [np.mean(k) for k in zip(*Earrays)]
#         # fig.axes[i].plot(d, Smean)
        
#         # plt.figure() # Create normalized SD curve for r as in fig. 5d of Aberra et al. 
#         E25[w] = [np.percentile(k, 25) for k in zip(*Earrays)]
#         E50[w] = [np.percentile(k, 50) for k in zip(*Earrays)]
#         E75[w] = [np.percentile(k, 75) for k in zip(*Earrays)]
    
#     xdurlist = data[cell_id][loc][w][0][0]['fit']['d']
#     plt.savefig('Figures/' + meta['cell_id'] + '/Energy-Duration_Mean_Fit.png')
#     plt.figure()
#     Enorm = min(Emean['r'])
#     for w in wlist:
#         plt.plot(xdurlist, Emean[w]/Enorm, label= w)
#     plt.legend()
#     plt.title('Fitted Mean Energy-Duration Curves')
#     Enorm = min(E50['r'])
#     fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
#     i=0
#     for w in wlist:
#         fig.axes[i].plot(xdurlist, E25[w]/Enorm, label = '25%')
#         fig.axes[i].plot(xdurlist, E50[w]/Enorm, label='median')
#         fig.axes[i].plot(xdurlist, E75[w]/Enorm, label='75%')
#         fig.axes[i].set_title(w)
#         fig.axes[i].set_xlabel('PW [ms]')
#         fig.axes[i].set_ylabel('E/$E_{norm}$')
#         i += 1
#     fig.axes[0].set_ylim([0, max(E75['r'])/Enorm])
#     fig.axes[0].legend()
#     fig.suptitle('Fitted Energy-duration curves')
#     fig.tight_layout()
#     fig.savefig('Figures/' + meta['cell_id'] + '/Energy-Duration_Seperated_Fit.png')
    
#     plt.figure()
    
#     for w in wlist:
#         plt.plot(xdurlist, E50[w]/Enorm, label= w)
#     plt.legend()
#     plt.title('Fitted Median Energy-Duration Curves')
#     plt.savefig('Figures/' + meta['cell_id'] + '/Strength-Duration_Median_Fit.png')
    
#     return 