# -*- coding: utf-8 -*-
"""
Created on Wed Jul 14 16:43:15 2021

Reproduction of the CalcVe.hoc functions in HOC code of Aberra et al. 2018
@author: francescvarkev
"""
import numpy as np
from neuron import h

def _electrode_distance(seg, xe, ye, ze):
    """Return the distance between segment `seg` and the point (xe, ye, ze); never zero."""
    r = np.sqrt((seg.x_xtra - xe)**2 + (seg.y_xtra - ye)**2 + (seg.z_xtra - ze)**2)
    if r == 0:
        # Electrode lies exactly on the segment coordinate, so 1/r would be
        # infinite. Fall back to the local radius so es_xtra stays finite.
        r = seg.diam/2
    return r


def _create_electrode_marker(section_name, marker_name, x, y, z):
    """Create a short HOC section around (x, y, z) with a PointProcessMark at its middle."""
    h('create {0}'.format(section_name))
    h('{0} pt3dclear()'.format(section_name))
    # The section spans x-5 .. x+5 at constant y, z; the arithmetic is evaluated by HOC.
    h('{0} pt3dadd({1}-5, {2}, {3}, 1)'.format(section_name, x, y, z))
    h('{0} pt3dadd({1}+5, {2}, {3}, 1)'.format(section_name, x, y, z))
    h('objref {0}'.format(marker_name))
    h('{0} {1} = new PointProcessMark(0.5)'.format(section_name, marker_name))


def calcesI(x0=200, y0=-50, z0=0, sigma_e=2.76e-07, bipolar=0, dist = 200, theta = 90, phi = 0, show_gui = False):
    """
    |  Calculates the linear multiplication factor (es_xtra) of every segment that has the
    |  "xtra" mechanism, for a point-source electrode at the given location.
    |  es_xtra is multiplied with the stimulation pulse during simulation.

    Parameters
    ----------
    x0 : float, optional
        x location of (working) electrode. The default is 200.
    y0 : float, optional
        y location of (working) electrode. The default is -50.
    z0 : float, optional
        z location of (working) electrode. The default is 0.
    sigma_e : float, optional
        Conductivity of field surrounding model in S/um. The default is 2.76e-07.
    bipolar : bool, optional
        Electrode configuration, 0 = monopolar, 1 = bipolar. Any other value
        leaves es_xtra untouched. The default is 0.
    dist : float, optional
        Distance between the two electrodes in case of bipolar configuration. The default is 200.
    theta : float, optional
        Angle theta in degrees (polar coordinates) of the return electrode relative
        to the working electrode, bipolar configuration only. The default is 90.
    phi : float, optional
        Angle phi in degrees (polar coordinates) of the return electrode relative
        to the working electrode, bipolar configuration only. The default is 0.
    show_gui : bool, optional
        If true, mark the electrode position(s) in the shape plot. This uses
        h.shplot and h.color_plotmax, which must already exist in the HOC
        interpreter. The default is False.

    Returns
    -------
    None.

    Notes
    -----
    If an electrode coincides exactly with a segment coordinate, the distance is
    replaced by the segment radius (seg.diam/2) so that es_xtra stays finite.
    """
    # NOTE(review): the 1e-3 factor presumably converts ohm*uA to mV - confirm
    # against the units expected by the xtra mechanism.
    if bipolar == 0:
        for sec in h.allsec():
            if not h.ismembrane("xtra", sec=sec):
                continue
            for seg in sec:
                r = _electrode_distance(seg, x0, y0, z0)
                seg.es_xtra = 1e-3/(4*np.pi*sigma_e*r)

        if show_gui:
            # mark electrode position in shape plot
            _create_electrode_marker('sElec', 'pElec', x0, y0, z0)
            h.color_plotmax()
            h.shplot.point_mark(h.pElec, 2, "O", 5)  # mark electrode point
            h.shplot.label(600, 100, "Electrode position", 1, 1, 0, 0, 2)

    elif bipolar == 1:
        # Return electrode: offset of length `dist` from the working electrode in
        # direction (theta, phi), rounded to one decimal.
        theta_rad = np.deg2rad(theta)
        phi_rad = np.deg2rad(phi)
        x1 = round((dist*np.cos(phi_rad)*np.sin(theta_rad) + x0), 1)
        y1 = round((dist*np.sin(phi_rad)*np.sin(theta_rad) + y0), 1)
        z1 = round((dist*np.cos(theta_rad) + z0), 1)

        for sec in h.allsec():
            if not h.ismembrane("xtra", sec=sec):
                continue
            for seg in sec:
                r0 = _electrode_distance(seg, x0, y0, z0)
                r1 = _electrode_distance(seg, x1, y1, z1)
                # Working-electrode term minus return-electrode term.
                seg.es_xtra = (1e-3/sigma_e)*(1/(4*np.pi*r0) - 1/(4*np.pi*r1))

        if show_gui:
            # mark electrode positions in shape plot
            _create_electrode_marker('sElec_W', 'pElec_W', x0, y0, z0)  # working electrode
            _create_electrode_marker('sElec_R', 'pElec_R', x1, y1, z1)  # return electrode

            h.color_plotmax()
            h.shplot.point_mark(h.pElec_R, 3, "O", 5)  # mark return electrode
            h.shplot.label(x1+20, y1+30, "RE", 1, 1, 0, 0, 3)
            h.shplot.point_mark(h.pElec_W, 2, "O", 5)  # mark working electrode
            h.shplot.label(x0+20, y0+30, "WE", 1, 1, 0, 0, 2)
                
def calceE(theta, phi):
    """
    Set es_xtra of every segment with the "xtra" mechanism for a unit electric
    field whose direction is given by the angles theta and phi.

    Parameters
    ----------
    theta : float
        Polar angle of the field direction in degrees.
    phi : float
        Azimuthal angle of the field direction in degrees.

    Returns
    -------
    None.

    """
    theta_rad = theta*np.pi/180
    phi_rad = phi*np.pi/180

    # Cartesian components of the unit field vector.
    sin_theta = np.sin(theta_rad)
    field_x = sin_theta*np.cos(phi_rad)
    field_y = sin_theta*np.sin(phi_rad)
    field_z = np.cos(theta_rad)

    for section in h.allsec():
        if not h.ismembrane("xtra", sec=section):
            continue
        for segment in section:
            # The field's y component acts on -z_xtra and its z component on y_xtra.
            # NOTE(review): presumably the field frame is rotated relative to the
            # model's coordinate frame - confirm this axis mapping is intended.
            projection = (field_x*segment.x_xtra
                          + field_y*(-segment.z_xtra)
                          + field_z*segment.y_xtra)
            # NOTE(review): the 1e-3 factor looks like a unit conversion - confirm.
            segment.es_xtra = -projection*1e-3
                

def getes(stimmode, x0 = 200, y0 = -50, z0 = 0, sigma_e = 2.76e-7, theta = 180, phi = 0):
    """
    Calculate unit Ve depending on stimulation mode.
    During simulation Ve is multiplied with stimulation pulse.

    Parameters
    ----------
    stimmode : int
        1 = point source (calls calcesI with x0, y0, z0, sigma_e and its
        remaining defaults), 2 = E-field (calls calceE with theta, phi).
        Any other value changes nothing and emits a RuntimeWarning.
    x0, y0, z0 : float, optional
        Electrode location, used only for stimmode = 1. Defaults are 200, -50, 0.
    sigma_e : float, optional
        Conductivity passed to calcesI, used only for stimmode = 1. The default is 2.76e-7.
    theta, phi : float, optional
        Field direction angles passed to calceE, used only for stimmode = 2.
        Defaults are 180 and 0.

    Returns
    -------
    None.

    """

    if stimmode == 1:
        calcesI(x0, y0, z0, sigma_e)
    elif stimmode == 2:
        calceE(theta, phi)
    else:
        # Nothing is recomputed for an unknown mode, so previously stored es_xtra
        # values would be reused silently; warn instead of failing so existing
        # callers keep working.
        import warnings
        warnings.warn(
            "getes: unsupported stimmode {0!r}; expected 1 (point source) or "
            "2 (E-field). es_xtra was left unchanged.".format(stimmode),
            RuntimeWarning,
            stacklevel=2,
        )
        