# -*- coding: utf-8 -*-
"""
Created on Mon Sep 27 16:40:37 2021

@author: francescvarkev
"""
from neuron import h
from config import setParams
import numpy as np


def initCell(cell_id=20, show_gui=False):
    r"""
    Initialize the NEURON model.

    Changes into the ``AberraEtAl2018`` directory (relative to the current
    working directory), loads the compiled mechanisms and ``init.hoc``,
    applies the 'AdultHuman' parameter set, loads the requested cell from the
    'Aberra et al. 2018' library and finally attaches a default stimulus
    waveform so that the ``stim_amp``/``stim_time`` vectors can be used.
    The original working directory is always restored, even on failure.

    Parameters
    ----------
    cell_id : int 1-25, optional
        Cell id in the library of 'Aberra et al. 2018'. The default is 20.
        IMPORTANT: To be able to load all cells, NSTACK should be increased to
        100000 and NFRAME to 20000. This can be done in: C:\nrn\lib\nrn.defaults.
    show_gui : bool, optional
        If true, ``neuron.gui`` is imported and the hoc variable ``gui`` is set
        to 1 (otherwise 0). The default is False.

    Returns
    -------
    None.
    """
    cell_id = int(cell_id)
    if show_gui:
        from neuron import gui  # noqa: F401 -- imported only for its side effect

    path = h.getcwd()
    h.chdir('AberraEtAl2018')
    try:
        h('gui = ' + str(int(show_gui)))
        # 'Ca' only exists once the mechanism library has been loaded; loading
        # it a second time in the same session is avoided.
        if not hasattr(h, 'Ca'):
            h.nrn_load_dll("nrnmech.dll")
        h.load_file("init.hoc")

        setParams('AdultHuman')

        h.cell_chooser(cell_id)
    finally:
        # Restore the caller's working directory even if loading failed.
        h.chdir(path)

    # call stim_waveform and attach_stim to be able to use stim_amp and stim_time vectors.
    h.stim_waveform(1, 1, -20)
    h.attach_stim()
    
def createRecordingDict():
    """
    Create NEURON recording vectors for time and membrane potential.

    Returns
    -------
    rec : dict
        'time'    : h.Vector recording the simulation time ``h.t``.
        'voltage' : dict mapping ``str(section)`` to an h.Vector recording
                    Vm at the middle (x=0.5) of every section that has the
                    ``xtra`` mechanism inserted.
    """
    time_vec = h.Vector().record(h._ref_t)  # create time record vector
    voltages = {
        str(section): h.Vector().record(section(0.5)._ref_v)
        for section in h.allsec()
        if h.ismembrane("xtra", sec=section)
    }
    return {'time': time_vec, 'voltage': voltages}
    
    
    
def findExLoc(rec):
    """
    Find the location(s) of activation in already-recorded simulation data.

    Intended to be used after a simulation with successful stimulation
    parameters has filled the vectors in ``rec`` (see createRecordingDict()).
    For every recorded section the first sample at which the sign of Vm
    changes is determined; the section(s) with the earliest such sample are
    reported as the excitation location.

    Parameters
    ----------
    rec : dict
        Recording dictionary with key 'time' (time vector) and key 'voltage'
        (dict mapping section name -> Vm vector). Every vector must provide a
        ``to_python()`` method.

    Returns
    -------
    v : dict
        Recorded Vm for each section (as returned by ``to_python()``).
    t : list
        Time vector of the recorded data.
    loc : list of str
        Name(s) of the section(s) whose Vm changed sign first.
    t_ex : int
        Sample INDEX of the earliest sign change, i.e. the index of the last
        sample before the change; ``t[t_ex]`` is the corresponding time.

    Raises
    ------
    ValueError
        If Vm does not change sign in any recorded section.
    """
    v = {}
    first_cross = {}  # section name -> index of its first sign change
    for name, vec in rec['voltage'].items():
        v[name] = vec.to_python()
        # A non-zero diff of the sign means Vm changed sign between samples i and i+1.
        crossings = np.flatnonzero(np.diff(np.sign(v[name])))
        if crossings.size:
            first_cross[name] = int(crossings[0])

    if not first_cross:
        raise ValueError('findExLoc: Vm did not change sign in any recorded section; '
                         'no excitation location found.')

    t_ex = min(first_cross.values())
    # Several sections may cross in the same sample; report all of them.
    loc = [name for name, idx in first_cross.items() if idx == t_ex]

    t = rec['time'].to_python()

    return v, t, loc, t_ex
    

def AddHighlight(sec, highlight_list=None, p=None):
    """
    Register a section to be highlighted later with PlotHighlights().

    A new marker section with two identical 3-D points at the first 3-D
    point of ``sec`` is created, and a PointProcessMark is placed in its
    middle. The marker sections are collected in ``highlight_list`` and the
    marks in ``p``. To accumulate several locations (e.g. in an iterative
    search for excitation locations), pass the returned lists back in on the
    next call. Then call PlotHighlights(p) to draw them.

    Parameters
    ----------
    sec : NEURON section
        Section to be added to the highlight list.
    highlight_list : list, optional
        Marker sections from previous calls. A new list is started when it is
        omitted (None). The list is extended in place when given.
    p : list, optional
        PointProcessMarks from previous calls. A new list is started whenever
        ``p`` is omitted or ``highlight_list`` is omitted/empty, so that both
        lists stay in step.

    Returns
    -------
    highlight_list : list
        Marker sections, including the one added by this call. Keep a
        reference to it; presumably NEURON deletes a Python-created section
        once no reference to it remains (verify).
    p : list
        PointProcessMarks, including the one added by this call.
    """
    # Fresh lists per call: the former mutable defaults were shared between
    # calls and got out of step (the first mark was dropped).
    if highlight_list is None:
        highlight_list = []
    # Starting a new highlight list also starts a new list of marks.
    if not highlight_list or p is None:
        p = []

    x = sec.x3d(0)
    y = sec.y3d(0)
    z = sec.z3d(0)

    marker = h.Section(name='highlight_list[' + str(len(highlight_list)) + ']')
    highlight_list.append(marker)
    marker.pt3dadd(x, y, z, 1)
    marker.pt3dadd(x, y, z, 1)
    p.append(h.PointProcessMark(0.5, sec=marker))

    return highlight_list, p
    
def PlotHighlights(p):
    """
    Draw the given PointProcessMarks on the NEURON shape plot ``shplot``.

    Parameters
    ----------
    p : list
        PointProcessMarks to draw, e.g. the second list returned by
        AddHighlight().

    Returns
    -------
    None.
    """
    h.color_plotmax()
    for mark in p:
        # style 2, marker "O", size 5 -- same settings as in calcminr/calcr
        h.shplot.point_mark(mark, 2, "O", 5)
        
def calcminr(x0=200, y0=400, z0=0, plot=False):
    """
    Return the smallest distance from the point (x0, y0, z0) to the neuron.

    Every segment of every section with the ``xtra`` mechanism is checked,
    using that segment's ``x_xtra``/``y_xtra``/``z_xtra`` values.

    Parameters
    ----------
    x0 : int, optional
        x-location. The default is 200.
    y0 : int, optional
        y-location. The default is 400.
    z0 : int, optional
        z-location. The default is 0.
    plot : bool, optional
        If true, the middle of the nearest section is marked in ``shplot``
        (through the hoc object ``rMIN``). The default is False.

    Returns
    -------
    rmin : float
        Minimal distance to the model. It is 1e6 if no segment is closer than
        that, for example when no section has ``xtra``.
    secmin : NEURON section
        Section containing the nearest segment, or 0 when none was found.
    """
    best_r = 1e6
    best_sec = 0
    xtra_sections = (s for s in h.allsec() if h.ismembrane("xtra", sec=s))
    for section in xtra_sections:
        for seg in section:
            dist = np.sqrt((seg.x_xtra - x0)**2 + (seg.y_xtra - y0)**2 + (seg.z_xtra - z0)**2)
            if dist < best_r:
                best_r, best_sec = dist, section

    if plot:
        h('objref rMIN')
        # place a mark in the middle of the nearest section
        h(str(best_sec) + ' rMIN = new PointProcessMark(0.5)')
        h.color_plotmax()
        h.shplot.point_mark(h.rMIN, 2, "O", 5)

    return best_r, best_sec
 

def calcr(sec, x0, y0, z0=0, plot=False):
    """
    Calculate the distance between a NEURON section and an x, y, z location.

    The section's ``x_xtra``/``y_xtra``/``z_xtra`` values are used, read
    through the section itself (i.e. at its middle).

    Parameters
    ----------
    sec : NEURON section
        Section with the ``xtra`` mechanism inserted.
    x0, y0, z0 : float
        Target location. ``z0`` defaults to 0.
    plot : bool, optional
        If true, the middle of the section is marked in ``shplot`` (through
        the hoc object ``rMIN``). The default is False.

    Returns
    -------
    r : float
        Calculated distance.

    Raises
    ------
    ValueError
        If ``sec`` does not have the ``xtra`` mechanism (and therefore no
        coordinates).
    """
    if not h.ismembrane("xtra", sec=sec):
        raise ValueError('section ' + str(sec) + ' has no xtra coordinates')

    r = np.sqrt((sec.x_xtra - x0)**2 + (sec.y_xtra - y0)**2 + (sec.z_xtra - z0)**2)

    if plot:
        h('objref rMIN')
        h(str(sec) + ' rMIN = new PointProcessMark(0.5)')  # middle of sec
        h.color_plotmax()
        h.shplot.point_mark(h.rMIN, 2, "O", 5)  # draw the mark

    return r
       
def calcLocspace(rmin, rmax, step, z=False):
    """
    Build a grid of candidate locations at a given distance range from the cell.

    The grid covers the bounding box of the ``xtra`` coordinates of all
    segments, and that box is sampled with spacing ``step``. The box always
    contains the origin because the bounds start at 0, and each bound is
    rounded to the nearest multiple of 100. Rounding can cut up to 50 units
    off an edge. A grid point is kept when its minimal distance to the cell
    (calcminr) lies strictly between ``rmin`` and ``rmax``.

    Parameters
    ----------
    rmin, rmax : float
        Exclusive lower and upper bound on the distance to the cell.
    step : int
        Grid spacing. It must be an int because it is passed to range().
    z : bool, optional
        If true, the grid also spans the z-extent of the cell. Otherwise only
        the plane z = 0 is sampled. The default is False.

    Returns
    -------
    xlist, ylist : range
        Grid coordinates along x and y.
    zlist : range or list
        Grid coordinates along z, or [0] when ``z`` is false.
    loclist : list of tuple
        (x, y, z) grid points with rmin < r < rmax.
    """
    xmin = xmax = ymin = ymax = zmin = zmax = 0

    for sec in h.allsec():
        if not h.ismembrane("xtra", sec=sec):
            continue
        for seg in sec:
            xmin, xmax = min(xmin, seg.x_xtra), max(xmax, seg.x_xtra)
            ymin, ymax = min(ymin, seg.y_xtra), max(ymax, seg.y_xtra)
            zmin, zmax = min(zmin, seg.z_xtra), max(zmax, seg.z_xtra)

    xmin = round(xmin, -2)
    xmax = round(xmax, -2)
    ymin = round(ymin, -2)
    ymax = round(ymax, -2)
    zmin = round(zmin, -2)
    zmax = round(zmax, -2)

    # "+ step" makes the upper bound inclusive when it lies on the grid.
    xlist = range(int(xmin), int(xmax + step), step)
    ylist = range(int(ymin), int(ymax + step), step)
    if z:
        zlist = range(int(zmin), int(zmax + step), step)
    else:
        zlist = [0]

    loclist = []
    # Grid coordinates use their own names so the flag parameter ``z`` is not shadowed.
    for gx in xlist:
        for gy in ylist:
            for gz in zlist:
                r = calcminr(gx, gy, gz)[0]
                if rmin < r < rmax:
                    loclist.append((gx, gy, gz))

    return xlist, ylist, zlist, loclist
        
def quitNeuron():
    """Call NEURON's ``h.quit()`` to end the NEURON session."""
    h.quit()
