# -*- coding: utf-8 -*-
"""
Created on Mon Sep 27 17:27:02 2021

@author: francescvarkev
"""
import os
import pickle
from time import time

from neuron import h

from model import calcminr, initCell, createRecordingDict
from runSim import SDdata
from config import setMeta

def savetoPickle(filename, data, meta):
    """Pickle a result set together with its metadata.

    The object written is ``{'data': data, 'meta': meta}`` and the target is
    ``DatafilesFinal/<meta['cell_id']>/<filename>.pkl`` (relative to the
    current working directory). The per-cell directory is created when it
    does not exist yet, so a completed simulation run is not lost at save
    time.

    Args:
        filename: base name of the output file, without the ``.pkl`` suffix.
        data: result object stored under the ``'data'`` key.
        meta: metadata dict stored under ``'meta'``; must contain ``'cell_id'``.
    """
    input_dict = {'data': data, 'meta': meta}
    out_dir = os.path.join("DatafilesFinal", str(meta['cell_id']))
    os.makedirs(out_dir, exist_ok=True)
    # 'with' guarantees the file is closed even if pickling raises.
    with open(os.path.join(out_dir, str(filename) + ".pkl"), "wb") as a_file:
        pickle.dump(input_dict, a_file)
    # str() so non-string filenames do not raise after the data was written.
    print('data saved to file ' + str(filename))

def savetoPickle_biphasic(filename, data, meta):
    """Pickle a biphasic-stimulation result set together with its metadata.

    The object written is ``{'data': data, 'meta': meta}`` and the target is
    ``Datafiles/Biphasic/<meta['cell_id']>/<filename>.pkl`` (relative to the
    current working directory). The per-cell directory is created when it
    does not exist yet.

    Args:
        filename: base name of the output file, without the ``.pkl`` suffix.
        data: result object stored under the ``'data'`` key.
        meta: metadata dict stored under ``'meta'``; must contain ``'cell_id'``.
    """
    input_dict = {'data': data, 'meta': meta}
    out_dir = os.path.join("Datafiles", "Biphasic", str(meta['cell_id']))
    os.makedirs(out_dir, exist_ok=True)
    # 'with' guarantees the file is closed even if pickling raises.
    with open(os.path.join(out_dir, str(filename) + ".pkl"), "wb") as a_file:
        pickle.dump(input_dict, a_file)
    # str() so non-string filenames do not raise after the data was written.
    print('data saved to file ' + str(filename))


def loadPickle(filename):
    """Load and return the object stored in the pickle file ``filename``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    # 'with' closes the file even if unpickling raises.
    with open(filename, "rb") as a_file:
        return pickle.load(a_file)
    
def run_and_save_locSweep(filename, cell_id, config):
    """Run strength-duration simulations over all configurations and locations.

    For every combination of waveform shape, polarity, angle and location
    listed in the metadata returned by ``setMeta(config, cell_id)``, ``SDdata``
    is called. Its per-duration ``'th'``, ``'e'`` and ``'q'`` values are
    collected into ``result[w][polarity][phi][loc]`` under the keys
    ``'s'``, ``'e'`` and ``'q'``, with the durations under ``'d'``, the raw
    SDdata dict under ``'SD'`` and the rounded ``calcminr`` distance under
    ``'dist'``.

    Locations for which ``SDdata`` returns an empty result are removed from
    ``meta['loclist']`` and therefore skipped for all later combinations.
    The result and metadata, including ``meta['elapsed_time']`` in hours,
    are saved with ``savetoPickle``.

    Args:
        filename: base name of the output pickle (without extension).
        cell_id: identifier passed to ``initCell`` and ``setMeta``.
        config: configuration passed to ``setMeta``.
    """
    t_start = time()

    ## Initialize the cell
    initCell(cell_id)
    meta = setMeta(config, cell_id)
    rec = createRecordingDict()
    result = {}

    for w in meta['shapelist']:
        result[w] = {}
        for polarity in meta['polaritylist']:
            result[w][polarity] = {}
            for phi in meta['anglelist']:
                result[w][polarity][phi] = {}
                # Iterate over a snapshot: locations may be removed from
                # meta['loclist'] below, and removing from the list being
                # iterated would silently skip the following location.
                for loc in list(meta['loclist']):
                    entry = {'dist': round(calcminr(loc[0], loc[1], loc[2])[0], 2)}
                    result[w][polarity][phi][loc] = entry

                    print(str(w) + ', ' + str(polarity) + ', ' + str(phi), ', ', str([loc]))

                    # NOTE(review): this uses the (min_dur, max_dur, dur_step)
                    # form of SDdata, whereas the parallel sweeps pass
                    # meta['dlist']; confirm which signature is current.
                    SD = SDdata(meta['min_dur'], meta['max_dur'], meta['dur_step'], w, rec,
                                x=loc[0], y=loc[1], z=loc[2], dt=meta['dt'],
                                depth=meta['resolution_depth'], bipolar=polarity, phi=phi)
                    if not SD:
                        # No threshold found: drop this location from later sweeps.
                        meta['loclist'].remove(loc)
                        continue

                    durations = list(SD)
                    entry['SD'] = SD
                    entry['s'] = [SD[dur]['th'] for dur in durations]
                    entry['d'] = durations
                    entry['e'] = [SD[dur]['e'] for dur in durations]
                    entry['q'] = [SD[dur]['q'] for dur in durations]

    t_stop = time()
    t_elapsed = (t_stop-t_start)/3600
    meta['elapsed_time'] = t_elapsed
    print('Time elapsed: ' + str(t_elapsed))

    savetoPickle(filename, result, meta)
    
    
def run_and_save_sigmaSweep(filename, cell_id, config):
    """Run strength-duration simulations for each value in ``meta['sigmalist']``.

    Only the first entry of ``shapelist``, ``polaritylist``, ``anglelist``
    and ``loclist`` (from ``setMeta(config, cell_id)``) is used. Each sigma
    is passed to ``SDdata`` as ``sigma=``. ``result[sigma]`` holds the rounded
    ``calcminr`` distance (``'dist'``). When SDdata returns data, it also
    holds the raw dict (``'SD'``), durations (``'d'``) and per-duration
    ``'th'``/``'e'``/``'q'`` values as ``'s'``/``'e'``/``'q'``.
    ``meta['duration']`` is set to the durations of the last non-empty result.

    When SDdata returns an empty result, the location is removed from
    ``meta['loclist']`` (once). The result and metadata, including
    ``meta['elapsed_time']`` in hours, are saved with ``savetoPickle``.

    Args:
        filename: base name of the output pickle (without extension).
        cell_id: identifier passed to ``initCell`` and ``setMeta``.
        config: configuration passed to ``setMeta``.
    """
    t_start = time()

    ## Initialize the cell
    initCell(cell_id)
    meta = setMeta(config, cell_id)

    rec = createRecordingDict()
    result = {}
    # Single stimulus configuration; sigma is the only swept variable.
    w = meta['shapelist'][0]
    polarity = meta['polaritylist'][0]
    phi = meta['anglelist'][0]
    loc = meta['loclist'][0]
    dist = round(calcminr(loc[0], loc[1], loc[2])[0], 2)

    for sigma in meta['sigmalist']:
        print(str(sigma))
        entry = {'dist': dist}
        result[sigma] = entry

        SD = SDdata(meta['min_dur'], meta['max_dur'], meta['dur_step'], w, rec,
                    x=loc[0], y=loc[1], z=loc[2], dt=meta['dt'],
                    depth=meta['resolution_depth'], bipolar=polarity, phi=phi, sigma=sigma)
        if not SD:
            # No threshold found. Remove the location only once: an empty
            # result for a second sigma would otherwise raise ValueError.
            if loc in meta['loclist']:
                meta['loclist'].remove(loc)
            continue

        durations = list(SD)
        meta['duration'] = durations
        entry['SD'] = SD
        entry['s'] = [SD[dur]['th'] for dur in durations]
        entry['d'] = durations
        entry['e'] = [SD[dur]['e'] for dur in durations]
        entry['q'] = [SD[dur]['q'] for dur in durations]

    t_stop = time()
    t_elapsed = (t_stop-t_start)/3600
    meta['elapsed_time'] = t_elapsed
    print('Time elapsed: ' + str(t_elapsed))

    savetoPickle(filename, result, meta)
 

def run_and_save_tauSweep(filename, cell_id, config):
    """Run strength-duration simulations for each value in ``meta['taulist']``.

    Only the first entry of ``shapelist``, ``polaritylist``, ``anglelist``
    and ``loclist`` (from ``setMeta(config, cell_id)``) is used. Each tau is
    passed to ``SDdata`` as ``tau=``. ``result[tau]`` holds the rounded
    ``calcminr`` distance (``'dist'``). When SDdata returns data, it also
    holds the raw dict (``'SD'``), durations (``'d'``) and per-duration
    ``'th'``/``'e'``/``'q'`` values as ``'s'``/``'e'``/``'q'``.
    ``meta['duration']`` is set to the durations of the last non-empty result.

    When SDdata returns an empty result, the location is removed from
    ``meta['loclist']`` (once). The result and metadata, including
    ``meta['elapsed_time']`` in hours, are saved with ``savetoPickle``.

    Args:
        filename: base name of the output pickle (without extension).
        cell_id: identifier passed to ``initCell`` and ``setMeta``.
        config: configuration passed to ``setMeta``.
    """
    t_start = time()

    ## Initialize the cell
    initCell(cell_id)
    meta = setMeta(config, cell_id)

    rec = createRecordingDict()
    result = {}
    # Single stimulus configuration; tau is the only swept variable.
    w = meta['shapelist'][0]
    polarity = meta['polaritylist'][0]
    phi = meta['anglelist'][0]
    loc = meta['loclist'][0]
    dist = round(calcminr(loc[0], loc[1], loc[2])[0], 2)

    for tau in meta['taulist']:
        print(str(tau))
        entry = {'dist': dist}
        result[tau] = entry

        SD = SDdata(meta['min_dur'], meta['max_dur'], meta['dur_step'], w, rec,
                    x=loc[0], y=loc[1], z=loc[2], dt=meta['dt'],
                    depth=meta['resolution_depth'], bipolar=polarity, phi=phi, tau=tau)
        if not SD:
            # No threshold found. Remove the location only once: an empty
            # result for a second tau would otherwise raise ValueError.
            if loc in meta['loclist']:
                meta['loclist'].remove(loc)
            continue

        durations = list(SD)
        meta['duration'] = durations
        entry['SD'] = SD
        entry['s'] = [SD[dur]['th'] for dur in durations]
        entry['d'] = durations
        entry['e'] = [SD[dur]['e'] for dur in durations]
        entry['q'] = [SD[dur]['q'] for dur in durations]

    t_stop = time()
    t_elapsed = (t_stop-t_start)/3600
    meta['elapsed_time'] = t_elapsed
    print('Time elapsed: ' + str(t_elapsed))

    savetoPickle(filename, result, meta)
    
def run_and_save_parallel_locSweep(loc, cell_id, meta):
    """Run the strength-duration sweep for a single location and save it.

    Presumably called once per location by a parallel driver (confirm with
    the caller). For every shape/polarity/angle in ``meta``, ``SDdata`` is
    called with the durations in ``meta['dlist']``.
    ``result[w][polarity][phi]`` holds the rounded ``calcminr`` distance
    (``'dist'``). When SDdata returns data, it also holds the raw dict
    (``'SD'``), durations (``'d'``) and per-duration ``'th'``/``'e'``/``'q'``
    values as ``'s'``/``'e'``/``'q'``.

    Side effect: sets ``meta['elapsed_time']`` (hours) on the passed-in dict.
    The result is saved with ``savetoPickle`` under the file name ``str(loc)``.

    Args:
        loc: electrode location; ``loc[0..2]`` are used as x, y, z.
        cell_id: identifier passed to ``initCell``.
        meta: metadata dict with the sweep lists and simulation settings.
    """
    t_start = time()

    ## Initialize the cell
    initCell(cell_id)
    rec = createRecordingDict()
    result = {}
    # The distance depends only on loc, so compute it once.
    dist = round(calcminr(loc[0], loc[1], loc[2])[0], 2)

    for w in meta['shapelist']:
        print(f'start {w}')
        result[w] = {}
        for polarity in meta['polaritylist']:
            result[w][polarity] = {}
            for phi in meta['anglelist']:
                entry = {'dist': dist}
                result[w][polarity][phi] = entry

                SD = SDdata(meta['dlist'], w, rec, x=loc[0], y=loc[1], z=loc[2],
                            dt=meta['dt'], depth=meta['resolution_depth'],
                            bipolar=polarity, phi=phi)
                if not SD:
                    # No threshold found: keep only the distance for this entry.
                    continue

                durations = list(SD)
                entry['SD'] = SD
                entry['s'] = [SD[dur]['th'] for dur in durations]
                entry['d'] = durations
                entry['e'] = [SD[dur]['e'] for dur in durations]
                entry['q'] = [SD[dur]['q'] for dur in durations]

    t_stop = time()
    t_elapsed = (t_stop-t_start)/3600
    meta['elapsed_time'] = t_elapsed
    print('Time elapsed: ' + str(t_elapsed))
    filename = str(loc)
    savetoPickle(filename, result, meta)

def parallel_add_durpoints(loc, cell_id, meta, durlist):
    """Add duration points to an existing per-location dataset.

    Loads ``Datafiles/<cell_id>/<str(loc)>.pkl`` and rounds ``durlist`` to 2
    decimals. Only the durations not yet in the stored ``meta['dlist']`` are
    simulated. For each shape/polarity/angle, the new SDdata points are
    merged into the stored ``'SD'`` dict. The ``'d'``/``'s'``/``'e'``/``'q'``
    lists are then rebuilt over all durations in sorted order. The dataset is
    saved again with ``savetoPickle`` under the file name ``str(loc)``.

    Side effects: sets ``meta['dlist']`` to the sorted union of stored and
    requested durations, and ``meta['elapsed_time']`` (hours). The passed-in
    ``meta`` replaces the stored metadata in the saved file.

    Args:
        loc: electrode location; ``loc[0..2]`` are used as x, y, z.
        cell_id: identifier passed to ``initCell`` and used in the load path.
        meta: metadata dict with the sweep lists and simulation settings.
        durlist: durations to add (any iterable of numbers).
    """
    t_start = time()
    durlist = [round(dur, 2) for dur in durlist]
    # NOTE(review): the dataset is loaded from 'Datafiles/' but savetoPickle
    # writes to 'DatafilesFinal/'; confirm this split is intended.
    output = loadPickle('Datafiles/' + str(cell_id) + '/' + str(loc) + '.pkl')
    result = output['data']

    ## simulate only durations that are not already in the dataset
    existing = set(output['meta']['dlist'])
    new_durs = [dur for dur in durlist if dur not in existing]
    # Record every duration present in the merged dataset.
    meta['dlist'] = sorted(existing.union(durlist))
    print('durlist: ', new_durs, loc)

    if not new_durs:
        print(loc, 'empty durlist')
    else:
        ## Initialize the cell
        initCell(cell_id)
        rec = createRecordingDict()

        for w in meta['shapelist']:
            for polarity in meta['polaritylist']:
                for phi in meta['anglelist']:
                    entry = result[w][polarity][phi]
                    SD = SDdata(new_durs, w, rec, x=loc[0], y=loc[1], z=loc[2],
                                dt=meta['dt'], depth=meta['resolution_depth'],
                                bipolar=polarity, phi=phi)

                    # Merge into the stored SD dict rather than replacing it;
                    # replacing would discard all previously simulated points.
                    # setdefault covers entries saved without an 'SD' key.
                    merged = entry.setdefault('SD', {})
                    merged.update(SD)

                    durations = sorted(merged)
                    entry['s'] = [merged[dur]['th'] for dur in durations]
                    entry['d'] = durations
                    entry['e'] = [merged[dur]['e'] for dur in durations]
                    entry['q'] = [merged[dur]['q'] for dur in durations]

    t_stop = time()
    t_elapsed = (t_stop-t_start)/3600
    meta['elapsed_time'] = t_elapsed
    print('Time elapsed: ' + str(t_elapsed))
    filename = str(loc)
    savetoPickle(filename, result, meta)

def run_and_save_parallel_locSweep_biphasic(loc, cell_id, meta, ipd=0.05):
    """Run the biphasic strength-duration sweep for a single location and save it.

    Same structure as the monophasic per-location sweep. ``SDdata`` is called
    with ``biphasic=True`` and ``ipd=ipd`` for every shape/polarity/angle in
    ``meta``, using the durations in ``meta['dlist']``.
    ``result[w][polarity][phi]`` holds the rounded ``calcminr`` distance
    (``'dist'``). When SDdata returns data, it also holds the raw dict
    (``'SD'``), durations (``'d'``) and per-duration ``'th'``/``'e'``/``'q'``
    values as ``'s'``/``'e'``/``'q'``.

    When SDdata returns an empty result, ``loc`` is removed (once, if
    present) from ``meta['loclist']``. Side effect: sets
    ``meta['elapsed_time']`` (hours). The result is saved with
    ``savetoPickle_biphasic`` under the file name ``str(loc)``.

    Args:
        loc: electrode location; ``loc[0..2]`` are used as x, y, z.
        cell_id: identifier passed to ``initCell``.
        meta: metadata dict with the sweep lists and simulation settings.
        ipd: forwarded to SDdata as ``ipd`` (default 0.05, the previously
            hard-coded value); presumably the inter-phase delay — confirm
            units against SDdata.
    """
    t_start = time()

    ## Initialize the cell
    initCell(cell_id)
    rec = createRecordingDict()
    result = {}
    # The distance depends only on loc, so compute it once.
    dist = round(calcminr(loc[0], loc[1], loc[2])[0], 2)

    for w in meta['shapelist']:
        result[w] = {}
        for polarity in meta['polaritylist']:
            result[w][polarity] = {}
            for phi in meta['anglelist']:
                entry = {'dist': dist}
                result[w][polarity][phi] = entry

                SD = SDdata(meta['dlist'], w, rec, x=loc[0], y=loc[1], z=loc[2],
                            dt=meta['dt'], depth=meta['resolution_depth'],
                            biphasic=True, ipd=ipd, bipolar=polarity, phi=phi)
                if not SD:
                    # No threshold found. Guard the removal: loc may already
                    # have been removed by an earlier combination (or never
                    # have been in this meta's list), which would raise ValueError.
                    if loc in meta['loclist']:
                        meta['loclist'].remove(loc)
                    continue

                durations = list(SD)
                entry['SD'] = SD
                entry['s'] = [SD[dur]['th'] for dur in durations]
                entry['d'] = durations
                entry['e'] = [SD[dur]['e'] for dur in durations]
                entry['q'] = [SD[dur]['q'] for dur in durations]

    t_stop = time()
    t_elapsed = (t_stop-t_start)/3600
    meta['elapsed_time'] = t_elapsed
    print('Time elapsed: ' + str(t_elapsed))
    filename = str(loc)
    savetoPickle_biphasic(filename, result, meta)
