# -*- coding: utf-8 -*-
"""
Created on Thu Mar 28 12:07:35 2024

@author: jingla
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.constants as c 
from scipy import optimize
from scipy.optimize import curve_fit
from scipy.signal import find_peaks, peak_prominences, savgol_filter
from plotsv1p0 import My_color_plot
from matplotlib.lines import Line2D
# Global matplotlib font settings applied to every figure in this script.
font = dict(family='DejaVu Sans', weight='normal', size='15')
plt.rc('font', **font)

######## Constants (SI units)
el = 1.602e-19           # elementary charge (C), approximate value of c.e
hbar = 1.05457182e-34    # reduced Planck constant (J s), approximate c.hbar
L = 4e-6                 # injector-detector contact separation (m)
B0 = 0.074               # field offset (T) subtracted from the measured B
Vg0 = 0.0                # backgate offset (V) in n = Cbg*(Vg - Vg0)/e
Cbg = 0.00144            # backgate capacitance per area, presumably F/m^2


def default_color_cycle():
    '''
    Reset ``axes.prop_cycle`` to matplotlib's ten standard default colours.

    Call this after a function that installed a custom colour cycle in
    ``plt.rcParams`` so that subsequent plots use C0, C1, ... as usual.
    '''
    tab10 = ('#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
             '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf')
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=list(tab10))

def Fermi_vel_BLG(n):
    '''
    Fermi velocity (m/s) of gapless bilayer graphene at carrier density n.

    Only |n| enters, so electrons and holes are treated symmetrically:

        vF = sqrt(2) * v0 * sqrt(sqrt(1 + x) - 1),
        x  = 4 * pi * hbar**2 * v0**2 * |n| / gamma1**2,

    with v0 = 1e6 m/s and gamma1 = 0.39 eV. At low density vF ~ v0*sqrt(x),
    i.e. vF grows as sqrt(n) as for parabolic bands.

    The expression equals sqrt(2 E_F / m*) with
    E_F = (gamma1/2) * (sqrt(1 + x) - 1) and m* = gamma1 / (2 v0**2).
    NOTE(review): this is not the group velocity (1/hbar) dE/dk of that
    dispersion, which saturates at v0 at high density. Confirm which
    definition is intended.

    Parameters
    ----------
    n : float or ndarray, carrier density in m^-2

    Returns
    -------
    float or ndarray, Fermi velocity in m/s
    '''
    v0 = 1E6                  # band velocity of the model (m/s)
    g1 = 0.39 * c.e           # interlayer coupling gamma1 (J)
    alpha = 1 / (c.pi*c.hbar**2*v0**2)
    x = 4*np.abs(n)/(alpha*g1**2)
    return np.sqrt(2)*v0*np.sqrt(np.sqrt(1 + x) - 1)

def scattering_rate_collimation(area, area0, L, vF):
    '''
    Scattering rate (1/s) from the suppression of the collimation peak area.

    Uses the formula of Science 353, 1526 (2016) without the factor pi in the
    denominator. That factor assumes trajectories of length pi*L, which holds
    for focusing but not for collimation, where the electrons are collected
    directly:

        rate = -(2 * vF / L) * ln(area / area0)

    Parameters
    ----------
    area : float or ndarray, area under the peak at the temperature of study
    area0 : float, area under the peak at base temperature
    L : float, contact separation in m
    vF : float, Fermi velocity in m/s

    Returns
    -------
    float or ndarray
        Scattering rate in s^-1, with the shape of ``area``. A zero area ratio
        gives inf and a negative ratio gives nan (``np.log`` semantics).
    '''
    return -2*vF/(L)*np.log(area/area0)

def my_round(num, dec=1):
    '''
    Format number(s) in scientific notation with ``dec`` decimals.

    Any scalar (Python float/int, numpy scalar, 0-d array) gives a single
    string. Previously only exact ``float`` was treated as a scalar, so e.g.
    ``np.float64`` raised TypeError. An iterable gives a list of strings, one
    per element.

    Parameters
    ----------
    num : scalar or iterable of numbers to be formatted
    dec : int, number of decimals of the mantissa

    Returns
    -------
    out : str for scalar input, list of str otherwise
    '''
    spec = '.{}e'.format(dec)
    if np.ndim(num) == 0:
        return format(num, spec)
    return [format(n, spec) for n in num]

def fit_func(x, a):
    '''Pure quadratic model a*x**2 (no linear or constant term), for curve_fit.'''
    return a * x**2

def fit_func2(x, a, b, c):
    '''
    General quadratic model a*x**2 + b*x + c, for curve_fit.

    The parameter ``c`` shadows the module-level ``scipy.constants`` alias
    only inside this function.
    '''
    quadratic = a * x**2
    linear = b * x
    return quadratic + linear + c

def processing_scattering_rate(Bpeak, Tout, Areapeak, BmaxL, BminL, Vbg, L, vF,
                               polyfit=False, col2=False):
    '''
    Plot the collimation scattering rate versus temperature and fit it.

    The rate is computed with ``scattering_rate_collimation`` (formula from
    Science 353, 1526 (2016)). Each window is normalised to its first
    selected area, and the result is plotted in ps^-1. The fitted parameters
    are printed and written on the figure but not returned.

    Parameters
    ----------
    Bpeak : sequence of B fields, one per entry of Tout/Areapeak. Only used
        to select which points enter each window.
    Tout : sequence of temperatures, same length as Bpeak
    Areapeak : sequence of peak areas, same length as Bpeak
    BmaxL : float or list of floats. Exclusive upper bound(s) on Bpeak.
    BminL : float or list of floats. Exclusive lower bound(s) on Bpeak.
    Vbg : backgate voltage, only used in the plot title
    L : separation between injector and detector (m)
    vF : Fermi velocity (m/s)
    polyfit : bool. If True the rate is fit to aT^2+bT+c, otherwise to aT^2.
    col2 : bool. If True everything is drawn in colour C2 and labelled p=2,
        for overlaying a second data set. Otherwise window k uses C{2k}.
    '''
    Bpeak = np.asarray(Bpeak)
    Tout = np.asarray(Tout)
    Areapeak = np.asarray(Areapeak)
    # Accept a single window given as scalars as well as lists of windows.
    if np.ndim(BminL) == 0:
        BminL = [BminL]
    if np.ndim(BmaxL) == 0:
        BmaxL = [BmaxL]
    for i, (Bmin, Bmax) in enumerate(zip(BminL, BmaxL)):
        if col2:
            i = 1
        colour = 'C{}'.format(i*2)
        in_window = (Bpeak > Bmin) & (Bpeak < Bmax)
        if not np.any(in_window):
            print('No points with {} < B < {}; window skipped'.format(Bmin, Bmax))
            continue
        T_sel = Tout[in_window]
        area_sel = Areapeak[in_window]
        # Normalise to the first selected point (assumed to be the base
        # temperature) and convert from s^-1 to ps^-1.
        rate = scattering_rate_collimation(area_sel, area_sel[0], L, vF)*1E-12
        plt.plot(T_sel, rate, 'o', label=' $p={}$'.format(i+1), c=colour)
        plt.title('$V_\\mathrm{}={}$ V'.format('{bg}', Vbg), fontsize=20)
        if polyfit:
            popt, pcov = curve_fit(fit_func2, T_sel, rate)
            perr = np.sqrt(np.diag(pcov))
            plt.plot(T_sel, fit_func2(T_sel, *popt), c=colour)
            print('Fit coefficient y=aT^2+bT+c:{}\n Error: {}'.format(popt, perr))
            plt.text(20, 0, '$\\tau_p^{}=aT^2+bT+c$ \n{}'.format('{-1}', my_round(popt)),
                     fontsize=15)
        else:
            popt, pcov = curve_fit(fit_func, T_sel, rate)
            perr = np.sqrt(np.diag(pcov))
            plt.plot(T_sel, fit_func(T_sel, *popt), c=colour)
            print('Fit coefficient y=aT^2:{}\n Error: {}'.format(popt, perr))
            plt.text(20, 0, '$y=aT^2$:{}\n Error: {}'.format(my_round(popt), np.round(perr, 2)),
                     fontsize=15)
    plt.tick_params(direction='in', top=True, right=True)
    plt.ylabel('$\\tau_p^{-1}$ (ps$^{-1}$)', fontsize=20)
    plt.xlabel('$T$ (K)', fontsize=20)
    return


def Bmax_fun(n):
    '''
    Field (T) at which the cyclotron diameter 2*hbar*k_F/(e*B) equals L.

    Uses k_F = sqrt(pi*n), with n in m^-2, and the module-level constants
    ``hbar``, ``el`` and ``L``.
    '''
    k_F = np.sqrt(np.pi*n)
    return 2*hbar*k_F/(el*L)

def area_under_peaks(Bmax, B, R, color='C0', offset=0, plot=True):
    '''
    Integrate the peak of R(B) above a straight baseline within |B| < Bmax.

    The baseline is the chord through the first and last points inside the
    window. A signal that is linear in B therefore has zero area, also when
    the sampled window is not exactly symmetric and for descending sweeps.

    Parameters
    ----------
    Bmax : float, half-width of the integration window (same units as B)
    B : ndarray, magnetic field values
    R : ndarray, resistance values, same length as B
    color : matplotlib colour for the optional overlay
    offset : float, vertical offset added to the overlay curves
    plot : bool, if True draw R, the baseline and the shaded area in between

    Returns
    -------
    float or None
        Trapezoidal integral of (R - baseline) over B, in units of R times B,
        or None if no point lies inside the window. The sign follows the
        ordering of B: a descending sweep gives the opposite sign.
    '''
    B = np.asarray(B)
    R = np.asarray(R)
    in_window = np.abs(B) < Bmax
    if not np.any(in_window):
        return None
    B = B[in_window]
    R = R[in_window]
    span = B[-1] - B[0]
    if span == 0:
        # Single point (or all B equal): no chord exists, use a flat baseline.
        baseline = np.full(R.shape, R[0], dtype=float)
    else:
        baseline = R[0] + (R[-1] - R[0]) * (B - B[0]) / span
    # np.trapz is deprecated in NumPy 2; np.trapezoid does not exist in 1.x.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    area_under_peak = trapezoid(R - baseline, B)
    if plot:
        plt.plot(B, R + offset, '-o', c=color)
        plt.plot(B, baseline + offset, '-', c=color)
        plt.fill_between(B, R + offset, baseline + offset,
                         color=color, alpha=0.5)
    return area_under_peak


def plot_collimation_temp(filename, I, Vbgrange=(0,), DR=50, DRT=0,
                          Trange=(), areaPlot=False, colorbar=True,
                          fsize=(6.4, 6.5)):
    '''
    Plot non-local collimation sweeps R(B) for several temperatures and
    backgate voltages, and extract the peak area and peak prominences.

    The files are concatenated row-wise. Each is read with
    ``np.loadtxt(skiprows=28)`` and must provide the columns B, Vbg, Vtg,
    Vdiff, VdiffY, Vdiff2, Vdiff2Y, time, Temp. Only samples whose temperature
    equals the previous sample (within 1e-4) are used to build the list of
    temperatures, which are rounded to the nearest kelvin. For every
    (T, Vbg) pair, points 400:800 of that subset form the sweep. The field is
    corrected by the module-level offset ``B0`` and plotted in mT.

    Side effects: opens a new figure, sets ``plt.rcParams['axes.prop_cycle']``
    to a jet colour cycle, prints each temperature, and stores the last sweep
    in the module globals ``Bcor`` (smoothed, offset-corrected B) and
    ``Rplot`` (Vdiff / I).

    Parameters
    ----------
    filename : str or list of str, .dat files with collimation measurements
        at different temperatures
    I : float, applied current in A
    Vbgrange : sequence of Vbg bounds. With two or more values, keep Vbg
        strictly between their min and max. With a single value v, keep
        |Vbg| > v.
    DR : numeric, R offset between sweeps at the same temperature and
        different Vbg
    DRT : numeric, R offset between sweeps at different temperatures
    Trange : sequence with the minimal and maximal temperature to plot
        (endpoints excluded). Empty means no temperature filter.
    areaPlot : bool, if True overlay the integrated peak area of each sweep
    colorbar : bool, if True draw a temperature colour bar
    fsize : figure size of the big figure

    Returns
    -------
    Temps : ndarray, selected (rounded) temperatures
    Bmax : list, Bmax_fun(n) in T for each sweep in loop order. When more than
        one sweep is plotted, one extrapolated value is added at each end
        (used for the grey shading).
    areas : list, peak area (R*T units) of each sweep in loop order
        (temperature outer, Vbg inner); None when the window is empty
    temp_proms : list, temperature of each detected peak
    proms : list, prominence of each detected peak (1-element arrays)
    Bpeaks : list, offset-corrected field (T) of each detected peak
    '''
    global Bcor, Rplot
    if isinstance(filename, str):
        filename = [filename]
    Vbgrange = np.atleast_1d(Vbgrange)
    Trange = np.atleast_1d(Trange)

    data = np.vstack([np.loadtxt(f, skiprows=28) for f in filename])
    B, Vbg, Vtg, Vdiff, VdiffY, Vdiff2, Vdiff2Y, time, Temp = data.T

    # Backgate voltages to plot (np.unique already returns them sorted).
    Vbgs = np.unique(Vbg)
    if len(Vbgrange) > 1:
        keep_vbg = (Vbgs < Vbgrange.max()) & (Vbgs > Vbgrange.min())
    else:
        keep_vbg = np.abs(Vbgs) > Vbgrange[0]
    Vbgs = Vbgs[keep_vbg]

    # Temperatures: samples equal to the previous one (the first sample is
    # compared with the last, as in np.roll), rounded to whole kelvin.
    stable = np.abs(Temp - np.roll(Temp, 1)) < 0.0001
    Temps = np.unique(np.round(Temp[stable]))
    if Trange.size > 0:
        Temps = Temps[(Temps < Trange.max()) & (Temps > Trange.min())]
    if Temps.size == 0:
        raise ValueError('No temperatures selected; check Trange and the data files.')

    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    colormap = plt.get_cmap('jet')
    color = [colormap(k) for k in np.linspace(0, 1, len(Temps))]
    plt.rcParams['axes.prop_cycle'] = plt.cycler(color=color)
    # The mappable colours the lines by T, so it is needed even without a
    # colour bar (previously colorbar=False raised NameError).
    norm = mpl.colors.Normalize(vmin=Temps.min(), vmax=Temps.max())
    smap = mpl.cm.ScalarMappable(norm=norm, cmap=colormap)
    smap.set_array([])

    Bmax = []
    offsets = []
    n_list = []
    areas = []
    proms = []
    Bpeaks = []
    temp_proms = []
    plt.figure(figsize=fsize)
    for j, T in enumerate(Temps):
        at_T = np.round(Temp, 0) == T
        VdiffT = Vdiff[at_T]
        VbgT = Vbg[at_T]
        BT = B[at_T]
        print('Temp={}'.format(T))
        for i, Vb in enumerate(Vbgs):
            at_Vb = VbgT == Vb
            Binc = BT[at_Vb][400:800]
            Vdiffinc = VdiffT[at_Vb][400:800]
            offset = i*DR + j*DRT
            offsets.append(offset)
            # Carrier density from the backgate, using the same (Vg - Vg0)
            # convention as the rest of the script.
            # NOTE(review): 0.001445 differs from the module constant
            # Cbg = 0.00144 - confirm which capacitance is correct.
            n = np.abs(0.001445*(Vb - Vg0)/el)
            n_list.append(n)
            Bmax.append(Bmax_fun(n))
            Bcor = savgol_filter(Binc, 5, 1) - B0
            Rplot = Vdiffinc/I
            plt.plot(Bcor*1e3, Rplot + offset, c=smap.to_rgba(T), linewidth=3,
                     label='$T={}$ K $V_\\mathrm{}={}$ V'.format(T, '{bg}', round(Vb, 2)))
            plt.plot([-250, 250], [offset, offset], '--', c='gray')
            # Feed area_under_peaks B in mT so its optional overlay matches
            # this mT axis; dividing by 1e3 returns the area in R*T units.
            area = area_under_peaks(Bmax_fun(n)*1e3, (Binc - B0)*1e3, Rplot,
                                    color=color[j], offset=offset, plot=areaPlot)
            if area is not None:
                area = area/1e3
            areas.append(area)
            # Peak prominence analysis (own loop variable, so the Vbg index
            # `i` used for the label below is not overwritten).
            peak_idx = find_peaks(Rplot, prominence=1)[0]
            for k in peak_idx:
                Bpeaks.append(Binc[k] - B0)
                temp_proms.append(T)
                proms.append(peak_prominences(Rplot, peaks=[k])[0])
            if j == 0:
                plt.text(-80, (i + 0.2)*DR,
                         '$V_\\mathrm{}={}$ V'.format('{bg}', round(Vb, 2)))
        plt.text(80, j*DRT + DRT/5, '$T$={} K'.format(T),
                 fontdict={'color': color[j]})

    # Pad offsets and densities by one step at each end so that the shaded
    # "B > Bmax" region covers the outermost sweeps.
    if len(offsets) > 1:
        doffsets = offsets[-1] - offsets[-2]
        offsets.append(offsets[-1] + doffsets)
        offsets.insert(0, offsets[0] - doffsets)
        dn_list = n_list[-1] - n_list[-2]
        n_list.append(n_list[-1] + dn_list)
        # NOTE(review): offsets are padded with "- doffsets" at the front but
        # n with "+ dn_list"; confirm whether n_list[0] - dn_list was intended.
        n_list.insert(0, n_list[0] + dn_list)
        Bmax.append(Bmax_fun(n_list[-1]))
        Bmax.insert(0, Bmax_fun(n_list[0]))
    # The x axis is in mT while Bmax is in T.
    Bmax_mT = 1e3*np.asarray(Bmax)
    plt.fill_betweenx(offsets, Bmax_mT, 250, facecolor='gray', alpha=0.25)
    plt.fill_betweenx(offsets, -Bmax_mT, -250, facecolor='gray', alpha=0.25)
    plt.xlim([-230, 230])
    plt.xlabel('$B$ (mT)', fontsize=20)
    plt.ylabel('$R_\\mathrm{nl}$ ($\\Omega$)', fontsize=20)
    plt.tick_params(direction='in', top=True, right=True)
    if colorbar:
        # An axes-less ScalarMappable needs an explicit ax on Matplotlib >= 3.6.
        plt.colorbar(smap, ax=plt.gca())
    return Temps, Bmax, areas, temp_proms, proms, Bpeaks


def Gauss(x, x0, w, amp):
    '''Area-normalised Gaussian centred at x0 with std w; integrates to amp.'''
    norm = w*np.sqrt(2*np.pi)
    z2 = (x - x0)**2/(2*w**2)
    return amp*np.exp(-z2)/norm

def Gauss2(x, x0a, x0b, wa, wb, ampa, ampb):
    '''Sum of two area-normalised Gaussians (a and b), see ``Gauss``.'''
    first = Gauss(x, x0a, wa, ampa)
    second = Gauss(x, x0b, wb, ampb)
    return first + second
#%% T dep 2 K to 100 K
# Five measurement files with the same configuration; plot_collimation_temp
# concatenates them.
filename = ['M{}_BLGhBN14_I32to02_100nA_V28to03_100x_V32to02_1x.dat'.format(m)
            for m in (59, 60, 65, 61, 66)]

#plot_collimation_temp(filename, I=1E-7, Vbgrange = [-3.1, 3], DR=50)

# Positive backgate window (-1 V < Vbg < 3 V), no vertical offsets.
Temps1, Bmax1, areas1, temp_proms1, proms1, Bpeaks1 = plot_collimation_temp(
    filename, I=1E-7, Vbgrange=[-1, 3], DR=0, DRT=0, Trange=[2, 75])
plt.xlim([-100, 100])
# Optional: plt.legend() / plt.ylim([-20, 820])
#plt.legend()
#plt.ylim([-20, 820])
plt.savefig('TdepCollimationVbg1p3V.pdf')
plt.show()

# Quick look at the peak area versus temperature for this window.
plt.plot(Temps1, areas1, 'o-')
plt.show()

# Negative backgate window (-3.5 V < Vbg < -2.5 V).
Temps2, Bmax2, areas2, temp_proms2, proms2, Bpeaks2 = plot_collimation_temp(
    filename, I=1E-7, Vbgrange=[-3.5, -2.5], DR=0, DRT=0, Trange=[2, 75])
plt.xlim([-200, 200])
#plt.legend()
#plt.ylim([-20, 820])
plt.savefig('TdepCollimationVbgm3V.pdf')
plt.show()
#%% Scattering rate from area
# NOTE: areas are returned once per (T, Vbg) pair, so this assumes each
# Vbgrange above selected a single backgate voltage.
default_color_cycle()

# Vbg = 1.3 V data set. Temps1 is sorted ascending, so [:-7] drops the seven
# highest temperatures.
Vg = 1.3
n = Cbg * (Vg - Vg0) / c.e
vF1 = Fermi_vel_BLG(n)
# Dummy peak positions: every 0 lies inside the (-0.1, 0.1) window, so all
# points are used.
Bpeak = [0]*len(Temps1)
plt.figure(figsize=[5.1, 3])
processing_scattering_rate(Bpeak[:-7], Temps1[:-7], areas1[:-7], BmaxL=[0.1],
                           BminL=[-0.1], Vbg=Vg, L=L, vF=vF1, polyfit=True)

# Vbg = -3 V data set, drawn in C2 on the same axes. The dummy array is sized
# from Temps2 so the boolean mask matches even if the temperature lists differ.
Vg = -3
n = Cbg * (Vg - Vg0) / c.e
vF2 = Fermi_vel_BLG(n)
Bpeak2 = [0]*len(Temps2)
processing_scattering_rate(Bpeak2, Temps2, areas2, BmaxL=[0.1],
                           BminL=[-0.1], Vbg=Vg, L=L, vF=vF2,
                           polyfit=True, col2=True)

legend_elements = [Line2D([0], [0], marker='o', color='w', label=r'$V_\mathrm{bg}=1.3$ V',
                          markerfacecolor='C0', markersize=8),
                   Line2D([0], [0], marker='o', color='w', label=r'$V_\mathrm{bg}=-3$ V',
                          markerfacecolor='C2', markersize=8)]


plt.legend(handles=legend_elements)
plt.savefig('ScatteringRateCollimation.pdf')
plt.show()
#%% Plot area vs temperature
plt.figure(figsize=[5.1, 3])
plt.tick_params(direction='in', top=True, right=True)
# Peak area versus temperature for both backgate voltages (C0: 1.3 V data
# set, C2: -3 V data set).
plt.plot(Temps1, areas1, 'o-', c='C0')
plt.plot(Temps2, areas2, 'o-', c='C2')
plt.ylim([0, 3])
plt.xlabel('$T$ (K)', fontsize=20)
plt.ylabel('Area ($\\Omega$T)', fontsize=20)
plt.savefig('AreaVsTCollimation.pdf')
