# -*- coding: utf-8 -*-
"""
Created on Fri Jun 18 11:16:35 2021

@author: francescvarkev

Generate different (monopolar) pulse shapes.
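waveform(c, c0, hc, w, ...) is the generic function to create a pulse with shape 'w'.
It samples the pulse on a uniform time grid and returns (wave, t); the individual
shape functions below each evaluate a single time instance x.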
"""

import numpy as np
from scipy.fft import fft, fftfreq
#example data for waveforms
from matplotlib import pyplot as plt
from scipy.signal import butter, lfilter, freqz
# a = 0
# b = 600e-6
# fs=1e6
# n = int((b-a)*fs)
# t = np.linspace(a,b,n)

# Example pulse parameters used by the commented-out demo snippets in this
# file: amplitude hc, duration c and start time c0.
hc, c, c0 = 1, 200e-6, 0

def rect_wave(x,c,c0,hc):
    '''
    Rectangular pulse.

    Returns hc strictly inside (c0, c0 + c) and 0.0 elsewhere; both edges
    count as "off".

    x = time instance, c = duration, c0 = start time, hc = amplitude
    '''
    outside = x >= c + c0 or x <= c0
    return 0.0 if outside else hc

# data=np.array([rect_wave(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.show()

# 
def triangle_wave(x,c,c0,hc):
    '''
    Centered triangular pulse.

    Rises linearly from 0 at c0 to hc at c0 + c/2, then falls linearly back
    to 0 at c0 + c. Returns 0.0 for x < c0 and x >= c0 + c.

    x = time instance, c = duration, c0 = start time, hc = amplitude
    '''
    if x >= c + c0 or x < c0:
        return 0.0
    rising = x*2*hc/c
    offset = 2*hc*c0/c
    if x < c0 + c/2:
        return rising - offset
    # Falling edge: mirror of the rising edge, shifted up by 2*hc.
    return -rising + offset + 2*hc

# data=np.array([triangle_wave(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.show()

def sine_wave(x,c,c0,hc):
    '''
    Sine pulse: one half period of a sine at frequency 1/(2c), so the pulse
    starts at 0 at c0, peaks at hc at c0 + c/2 and returns to 0 at c0 + c.
    Returns 0.0 for x < c0 and x >= c0 + c.

    x = time instance, c = duration, c0 = start time, hc = amplitude
    '''
    before_start = x < c0
    after_end = x >= c + c0
    if after_end or before_start:
        return 0.0
    angle = 2*np.pi*(x-c0)/(2*c)
    return hc*np.sin(angle)

# data=np.array([sine_wave(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.show()

def sine_complex(x,c,c0,hc):
    '''
    Complex sine pulse: the average of two sines, one at f = 1.5/(2c) and one
    at f = 0.5/(2c), scaled by hc. Over the pulse the faster sine spans three
    quarters of a period and the slower one a quarter period, so their sum
    starts and ends at zero. Returns 0.0 for x < c0 and x >= c0 + c.

    x = time instance, c = duration, c0 = start time, hc = amplitude
    '''
    if x >= c + c0 or x < c0:
        return 0.0
    elapsed = x - c0
    fast = np.sin(2*np.pi*elapsed*(1.5/(2*c)))
    slow = np.sin(2*np.pi*elapsed*(0.5/(2*c)))
    return hc*(fast+slow)/2
  
# data=np.array([sine_complex(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.show()  

def gaus_trunc(x,c,c0,hc, sigma=50e-6):
    '''
    Truncated Gaussian pulse with its mean at the pulse centre (c0 + c/2) and
    an adjustable standard deviation. Returns 0 outside the open interval
    (c0, c0 + c).

    x = time instance, c = duration, c0 = start time, hc = peak amplitude,
    sigma = standard deviation (same time units as x)
    '''
    if x <= c0 or x >= c + c0:
        return 0
    distance = x - c0 - c/2
    return hc*np.exp(-(distance**2/(2*sigma**2)))

# data=np.array([gaus_trunc(x,c,c0,hc, c/6) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.grid()
# plt.show()

def gaus_trunc_shifted(x,c,c0,hc, sigma=50e-6):
    '''
    Truncated Gaussian pulse with a movable mean and a fixed width.

    The mean sits at c0 + c*shift. The shift is passed through the argument
    named ``sigma`` so that callers built around gaus_trunc (e.g. the runSim
    functions) can be reused unchanged. The width is hard-coded to 0.1 in the
    same time units as x. NOTE(review): with microsecond-scale durations
    that width makes the pulse nearly flat; confirm whether a relative width
    was intended. Returns 0 outside the open interval (c0, c0 + c).

    x = time instance, c = duration, c0 = start time, hc = amplitude,
    sigma = shift of the mean as a fraction of c
    '''
    shift, width = sigma, 0.1
    if x <= c0 or x >= c + c0:
        return 0
    distance = x - c0 - c*shift
    return hc*np.exp(-(distance**2/(2*width**2)))

# data=np.array([gaus_trunc(x,c,c0,hc, c/6) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.grid()
# plt.show()

def gaus_wave(x,c,c0,hc):
    '''
    Untruncated Gaussian pulse peaking at c0 + c/2 with sigma = c/6, so the
    duration c spans +/- 3 sigma around the peak. The value is non-zero for
    every x (there is no cut-off outside the pulse window).

    x = time instance, c = duration, c0 = start time, hc = peak amplitude
    '''
    spread = c/6
    distance = x - c0 - c/2
    return hc*np.exp(-(distance**2/(2*spread**2)))

# data=np.array([gaus_wave(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.grid()
# plt.show()

def sinc_wave(x, c, c0, hc):
    '''
    Centered sinc pulse, rectified with np.abs so it stays monopolar.
    The main lobe peaks at hc at c0 + c/2. Returns 0 outside the open
    interval (c0, c0 + c).

    NOTE(review): np.sinc is the normalized sinc, sin(pi*y)/(pi*y). With the
    extra pi in the argument, zero crossings are c/(4*pi) apart, which gives
    about six lobes on each side rather than the "frequency 2/c, 4 lobes per
    side" previously described. Confirm the intended lobe count.

    x = time instance, c = duration, c0 = start time, hc = amplitude
    '''
    if x <= c0 or x >= c + c0:
        return 0
    argument = 4*np.pi*(x-c0-c/2)/c
    return hc*np.abs(np.sinc(argument))

# data=np.array([sinc_wave(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.show()

def HFS(x, c, c0, hc, HF=100, D=0.5):
    '''
    High-frequency switched rectangular pulse: a rectangular pulse of duration
    c chopped into on/off cycles repeating at frequency HF with duty cycle D.
    The sampling rate must be well above HF to resolve the individual cycles.

    x = time point, c = duration, c0 = start time, hc = amplitude,
    HF = repetition frequency of the on/off cycles (cycle length 1/HF),
    D = duty cycle, i.e. the fraction of each cycle during which the output is hc.

    Returns hc during the "on" part of a cycle inside (c0, c0 + c], and 0
    everywhere else, including outside the pulse.
    '''
    # Default to "off": any x inside the pulse not matched by a cycle below
    # used to leave r unbound (UnboundLocalError).
    r = 0
    if x > c0 and x <= c0+c:
        # ceil, not int(): int(c*HF) dropped a trailing partial cycle (and a
        # whole cycle whenever c*HF lands just below an integer in floating
        # point, e.g. c*(100/c) -> 99.999...), and was 0 when c*HF < 1.
        # Extra iterations only matter for x that previously matched nothing.
        n_cycles = int(np.ceil(c*HF))
        for i in range(n_cycles):
            if x <= c0+(D+i)/HF and x >= c0+i/HF:
                # "on" part of cycle i
                r = hc
                break
            elif x < c0+(i+1)/HF and x > c0+(D+i)/HF:
                # "off" part of cycle i
                r = 0
                break
    return r

# data=np.array([HFS(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(hc+0.2))
# plt.plot(t,data)
# plt.show()

def exp_rise(x, c, c0, hc, tau=200e-6):
    '''
    Exponentially rising pulse: hc*exp((x - c - c0)/tau) inside (c0, c0 + c).
    It grows with time constant tau and approaches its maximum hc at the
    trailing edge c0 + c. Returns 0 outside that open interval.

    x = time point, c = duration, c0 = start time, hc = max amplitude, tau = time constant
    '''
    if x <= c0 or x >= c + c0:
        return 0
    return hc*np.exp((x-c-c0)/tau)

# data=np.array([exp_rise(x,c,c0,hc, tau=c) for x in t])
# plt.ylim(-0.2,(5*hc))
# plt.plot(t,data)
# plt.show()

def exp_decay(x, c, c0, hc, tau=200e-6):
    '''
    Exponentially decaying pulse: hc*exp(-(x - c0)/tau) inside (c0, c0 + c).
    It starts near hc just after c0 and decays with time constant tau.
    Returns 0 outside that open interval.

    x = time point, c = duration, c0 = start time, hc = initial amplitude, tau = time constant
    '''
    if x <= c0 or x >= c + c0:
        return 0
    elapsed = x - c0
    return hc*np.exp(-elapsed/tau)

# data=np.array([exp_decay(x,c,c0,hc, tau=c/2) for x in t])
# plt.ylim(-0.2,(2*hc))
# plt.plot(t,data)
# plt.show()

def ramp_up(x, c, c0, hc):
    '''
    Linearly rising ramp pulse: climbs from 0 at c0 towards hc at c0 + c.
    Returns 0 outside the open interval (c0, c0 + c).

    x = time point, c = duration, c0 = start time, hc = final amplitude
    '''
    if x <= c0 or x >= c + c0:
        return 0
    return hc*(x - c0)/c

# data=np.array([ramp_up(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(2*hc))
# plt.plot(t,data)
# plt.show()

def ramp_down(x, c, c0, hc):
    '''
    Linearly falling ramp pulse: starts near hc just after c0 and falls
    towards 0 at c0 + c. Returns 0 outside the open interval (c0, c0 + c).

    x = time point, c = duration, c0 = start time, hc = initial amplitude
    '''
    if x <= c0 or x >= c + c0:
        return 0
    fraction_elapsed = (x - c0)/c
    return hc*(1 - fraction_elapsed)

# data=np.array([ramp_up(x,c,c0,hc) for x in t])
# plt.ylim(-0.2,(2*hc))
# plt.plot(t,data)
# plt.show()



def opt_wave(x, c, c0, hc, tau, x1):
    '''
    Optimal waveform determined analytically.

    With u = x - c0, the pulse inside (c0, c0 + c) is
        hc * (cosh(u/tau) + x1*sinh(u/tau)) / (cosh(c/tau) + x1*sinh(c/tau)),
    written with exponentials (the common factor 2 cancels). It reaches hc at
    the trailing edge c0 + c. Returns 0 outside that open interval.
    NOTE(review): for large c/tau the exponentials overflow to inf and the
    ratio becomes nan.

    x = time point, c = duration, c0 = start time, hc = max amplitude, tau = time constant, x1 = ratio between membrane parameters.
    '''
    if x <= c0 or x >= c + c0:
        return 0
    grow = np.exp((x-c0)/tau)
    shrink = np.exp((c0-x)/tau)
    grow_end = np.exp((c)/tau)
    shrink_end = np.exp((-c)/tau)
    numerator = (grow+shrink)+x1*(grow-shrink)
    denominator = (grow_end+shrink_end)+x1*(grow_end-shrink_end)
    return hc*numerator/denominator
    
    
# def waveform(c, c0, hc, w, dt=None, t0=0, tstop=None, HF=None, D=0.5, tau=None, sigma=None, biphasic = False, ipd= None, ratio = 1):
#     """
#     Generic function for generating defined pulses. Shape is defined by input variable w.

#     Parameters
#     ----------
#     x : float
#         Time point.
#     c : float
#         Pulse duration.
#     c0 : float
#         Start time.
#     hc : float
#         Pulse amplitude.
#     w : string
#         Pulse shape:
#             |  'r': rectangular
#             |  't': triangular
#             |  's': sine
#             |  'sc': complex sine
#             |  'g': gaussian
#             |  'snc': sinc
#             |  'HFS': high frequency sampled pulse
#             |  'er': exponentially rising
#             |  'ed': exponentially decaying
#     HF : float, optional
#         Sampling frequency in case of a HFS pulse. The default is 100.
#     D : float, optional
#         Duty cycle in case of a HFS pulse. The default is 0.5.

#     Returns
#     -------
#     x : float
#         Amplitude of the pule at time instance x.
        
#     Example Usage:
#     ------
#     |  from matplotlib import pyplot as plt
#     |  a = 0
#     |  b = 600e-6
#     |  fs=1e6
#     |  n = int((b-a)*fs)
#     |  t = np.linspace(a,b,n)
#     |  
#     |  hc = 1
#     |  c = 100e-6
#     |  c0 = 0
#     |  w = 'r' 
#     |  
#     |  data = np.array([waveform(x, c, c0, hc, w) for x in t])
#     |  plt.plot(t, data)
#     |  plt.show()
#     """
#     if ipd is None:
#         ipd = c/10
        
#     if tstop is None:
#         if biphasic is False:
#             tstop = 2*c + c0
#         else:
#             tstop = (2+ratio)*c + ipd + c0

#     if dt is None:
#         dt = c/1e3
    
#     if tau is None:
#         tau = c/6
    
#     if sigma is None:
#         sigma = c/6
        
#     if HF is None:
#         HF = 100/c
        
#     n = int((tstop-t0)/dt)
#     t = np.linspace(t0, tstop, n)
#     wave = np.array([])
#     phase = 0
#     for x in t:
#         if biphasic is True:
#             if x >= (c + c0 + ipd) and phase == 0:
#                 c0 = c + c0 + ipd
#                 hc = -hc/ratio
#                 c = c*ratio
#                 phase = 1
                
#         if w == 'r':
#             wave = np.append(wave, rect_wave(x,c,c0,hc))
#         elif w == 't':
#             wave = np.append(wave, triangle_wave(x, c, c0, hc))
#         elif w == 's':
#             wave = np.append(wave, sine_wave(x, c, c0, hc))
#         elif w =='sc':
#             wave = np.append(wave, sine_complex(x, c, c0, hc))
#         elif w == 'g':
#             wave = np.append(wave, gaus_wave(x, c, c0, hc))
#         elif w == 'gt':
#             wave = np.append(wave, gaus_trunc(x, c, c0, hc, sigma))
#         elif w == 'gts':
#             wave = np.append(wave, gaus_trunc_shifted(x, c, c0, hc, sigma))
#         elif w == 'snc':
#             wave = np.append(wave, sinc_wave(x, c, c0, hc))
#         elif w == 'HFS':
#             wave = np.append(wave, HFS(x, c, c0, hc, HF, D))
#         elif w == 'er':
#             wave = np.append(wave, exp_rise(x, c, c0, hc, tau))
#         elif w == 'ed':
#             wave = np.append(wave, exp_decay(x, c, c0, hc, tau))
#         elif w == 'ru':
#             wave = np.append(wave, ramp_up(x, c, c0, hc))
#         else:
#             print('w does not match')
#             break
        
#     return wave, t

def butter_lowpass(cutoff, dt, order=1):
    '''
    Design a digital Butterworth low-pass filter.

    cutoff = cutoff frequency (same units as 1/dt), dt = sampling interval,
    order = filter order. Returns the (b, a) transfer-function coefficients.
    '''
    nyquist = 0.5 / dt
    # scipy expects the cutoff as a fraction of the Nyquist frequency.
    coeffs_b, coeffs_a = butter(order, cutoff / nyquist, btype='low', analog=False)
    return coeffs_b, coeffs_a


def butter_lowpass_filter(data, cutoff, dt, order=1):
    '''
    Low-pass filter ``data`` (sampled every dt) with the Butterworth filter
    designed by butter_lowpass, applied in a single forward pass via
    scipy.signal.lfilter. Returns the filtered signal.

    cutoff = cutoff frequency (same units as 1/dt), order = filter order.
    '''
    return lfilter(*butter_lowpass(cutoff, dt, order=order), data)

def waveform(c, c0, hc, w, dt=None, t0=0, tstop=None, HF=None, D=0.5, tau=None, sigma=None, biphasic = False, ipd= None, ratio = 1, filtered = False, cutoff = 1):
    """
    Generic function for generating defined pulses. Shape is defined by input variable w.

    The pulse is sampled on a uniform time grid with step dt and returned
    together with that grid.

    Parameters
    ----------
    c : float
        Pulse duration (of the first phase when biphasic).
    c0 : float
        Start time of the pulse.
    hc : float
        Pulse amplitude (of the first phase when biphasic).
    w : string
        Pulse shape:
            |  'r': rectangular
            |  't': triangular
            |  's': sine
            |  'sc': complex sine
            |  'g': gaussian (untruncated)
            |  'gt': truncated gaussian, width given by sigma
            |  'gts': truncated gaussian with shifted mean (sigma is the shift)
            |  'snc': sinc
            |  'HFS': high frequency switched rectangular pulse
            |  'er': exponentially rising
            |  'ed': exponentially decaying
            |  'ru': ramp up
            |  'rd': ramp down
        Any other value prints 'w does not match' and yields an all-zero wave.
    dt : float, optional
        Sampling step. The default is c/1e3.
    t0 : float, optional
        First time point. The default is 0.
    tstop : float, optional
        End of the time window. The default is 2*c + c0 for a monophasic
        pulse and (2+ratio)*c + ipd + c0 for a biphasic one.
    HF : float, optional
        Switching frequency of an 'HFS' pulse. The default is 100/c.
    D : float, optional
        Duty cycle of an 'HFS' pulse. The default is 0.5.
    tau : float, optional
        Time constant for 'er' and 'ed'. The default is c/6.
    sigma : float, optional
        Sigma for 'gt' (the mean shift for 'gts'). The default is c/6.
    biphasic : bool, optional
        Add a second phase of opposite polarity. The default is False.
    ipd : float, optional
        Inter-phase delay between the two phases. The default is c/10.
    ratio : float, optional
        The second phase lasts ratio*c with amplitude -hc/ratio. The default is 1.
        tau, sigma and HF are not rescaled for the second phase.
    filtered : bool, optional
        Apply a Butterworth low-pass filter to the wave. The default is False.
    cutoff : float, optional
        Cutoff frequency of that filter (same units as 1/dt). The default is 1.

    Returns
    -------
    wave : numpy.ndarray
        Pulse amplitude at each time point.
    t : numpy.ndarray
        Time points t0 + k*dt for k = 0 .. n-1, with n = int((tstop - t0)/dt).

    Example Usage:
    ------
    |  from matplotlib import pyplot as plt
    |  wave, t = waveform(100e-6, 0, 1, 'r')
    |  plt.plot(t, wave)
    |  plt.show()
    """
    if ipd is None:
        ipd = c/10

    # Truthiness, not identity: numpy/int flags (e.g. np.True_) must also work;
    # with `is True` / `is False` they picked the biphasic window but never
    # generated the second phase.
    if tstop is None:
        if not biphasic:
            tstop = 2*c + c0
        else:
            tstop = (2+ratio)*c + ipd + c0

    if dt is None:
        dt = c/1e3

    if tau is None:
        tau = c/6

    if sigma is None:
        sigma = c/6

    if HF is None:
        HF = 100/c

    # Shape name -> f(x, c, c0, hc). c, c0 and hc are passed on every call
    # because the biphasic switch below changes them for the second phase.
    shapes = {
        'r': rect_wave,
        't': triangle_wave,
        's': sine_wave,
        'sc': sine_complex,
        'g': gaus_wave,
        'gt': lambda x, c, c0, hc: gaus_trunc(x, c, c0, hc, sigma),
        'gts': lambda x, c, c0, hc: gaus_trunc_shifted(x, c, c0, hc, sigma),
        'snc': sinc_wave,
        'HFS': lambda x, c, c0, hc: HFS(x, c, c0, hc, HF, D),
        'er': lambda x, c, c0, hc: exp_rise(x, c, c0, hc, tau),
        'ed': lambda x, c, c0, hc: exp_decay(x, c, c0, hc, tau),
        'ru': ramp_up,
        'rd': ramp_down,
    }

    n = int((tstop-t0)/dt)
    # Samples exactly dt apart. np.linspace(t0, tstop, n) spaced them
    # (tstop-t0)/(n-1) apart, disagreeing with the dt used by the filter
    # (and by any caller replaying the wave at step dt).
    t = t0 + dt*np.arange(n)
    wave = np.zeros(len(t))

    shape = shapes.get(w)
    if shape is None:
        print('w does not match')
    else:
        second_phase = False
        for i, x in enumerate(t):
            if biphasic and not second_phase and x >= (c + c0 + ipd):
                # Switch to the opposite-polarity phase after the inter-phase delay.
                c0 = c + c0 + ipd
                hc = -hc/ratio
                c = c*ratio
                second_phase = True
            wave[i] = shape(x, c, c0, hc)

    if filtered:
        wave = butter_lowpass_filter(wave, cutoff, dt)
    return wave, t
    