
## needed imports 
import matplotlib.pyplot as plt
from quantify_core.data.handling import load_dataset, to_gridded_dataset, get_datadir
from scipy.signal import savgol_filter, find_peaks
import glob 
import numpy as np


#%% Define the color palette

import numpy as np

def hex2rgb(hex):
    """Convert a hex colour string such as '#1d71b8' to an (r, g, b) tuple.

    Leading '#' characters are stripped, then the first six remaining
    characters are read as three two-digit hexadecimal numbers.
    """
    digits = hex.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return (red, green, blue)

def rgb2hex(rgb):
    """Convert an iterable of integer channel values to a '#rrggbb' style string.

    Every channel is written as two lowercase hexadecimal digits.

    Raises:
        TypeError: if a channel is not an integer type (e.g. a float).
        ValueError: if a channel lies outside 0..255. Such values cannot be
            written with two hex digits; they used to be truncated silently
            (256 became '00') or produced an invalid string (-1 became 'x1').
    """
    import operator

    digits = []
    for channel in rgb:
        # operator.index accepts Python and numpy integers but rejects floats
        value = operator.index(channel)
        if not 0 <= value <= 255:
            raise ValueError(
                'colour channel out of range 0..255: {!r}'.format(channel))
        digits.append(format(value, '02x'))
    return '#' + ''.join(digits)

def color_mixer(hex1, hex2, wt1=0.5):
    """Blend two hex colours channel by channel in RGB space.

    `wt1` is the weight given to `hex1`; `hex2` gets `1.0 - wt1`. Each blended
    channel is truncated with int() before the result is converted back to a
    hex string.
    """
    wt2 = 1.0 - wt1
    blended = []
    for channel1, channel2 in zip(hex2rgb(hex1), hex2rgb(hex2)):
        blended.append(int(wt1 * channel1 + wt2 * channel2))
    return rgb2hex(tuple(blended))

def create_palette(start_color, mid_color, end_color, num_colors):
    """Build a list of hex colours running start_color -> mid_color -> end_color.

    Two ramps of `num_colors` steps each are mixed (start -> mid and
    mid -> end), joined without repeating `mid_color`, and every second entry
    is kept, which leaves `num_colors` colours. For `num_colors >= 2` the list
    starts at `start_color` and ends at `end_color`; the exact `mid_color`
    entry survives the subsampling only when `num_colors` is odd.

    The ramps are built from the arguments. The module-level `first_color` /
    `last_color` are no longer read here, so the function works for any pair
    of end colours.
    """
    # weight of the first colour of each mix, running from 1 down to 0
    steps = np.linspace(0, 1, num_colors)[::-1]

    first_half = [color_mixer(start_color, mid_color, wt) for wt in steps]
    # steps[1:] skips weight 1 so that mid_color is not listed twice
    second_half = [color_mixer(mid_color, end_color, wt) for wt in steps[1:]]

    # 2*num_colors - 1 entries in total; take every second one
    return (first_half + second_half)[::2]

from matplotlib.colors import ListedColormap, LinearSegmentedColormap

# End colours of the map as RGB triplets, plus an off-white mid-point.
first_color = rgb2hex((29, 113, 184))  # colour used for Si
last_color = rgb2hex((190, 22, 34))  # colour used for Ge
mid_color = '#fefefe'

# 100 hex colours running first_color -> mid_color -> last_color
result = create_palette(first_color, mid_color, last_color, 100)

# colour map used by the pcolormesh plots below
my_cmap = ListedColormap(result)

#%%

## Load the data
# Directory in which the data folder is located.
path = r"W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper\QuantumInspire\ValleyData" 

# Measurement ids: the 'tuid' is the id found in the name of each data folder.
tuid = ['20230203-164033-325-464479',  # P1
        '20230208-174113-861-6eeb38',  # P2
        '20230214-235617-893-7f08f8',  # P3
        '20230127-172543-452-c90cb7',  # P4
        '20230203-004000-277-a8d475',  # P5
        '20230217-235220-745-16af9b']  # P6

# Index into `tuid` of the dataset to analyse. The same index is used further
# down to look up the fit window in `rere`.
number = 2

# The selected measurement is loaded into `dset`.
dset = load_dataset(tuid=tuid[number], datadir=path)
 
#%%
# Inspect a single trace: find its half-height crossing and draw three
# diagnostic figures (2D map, raw and smoothed trace, smoothed derivative).

plt.rcParams['figure.autolayout'] = False
plt.rcParams['figure.constrained_layout.use'] = True
plt.rcParams['pdf.fonttype'] = 42

data = to_gridded_dataset(dset)

x = data.x0.data
y = data.x1.data
z = data.y0.data

# index along y of the trace to inspect (requires more than 120 points in y)
n = 120

# after transposing, aux[j] is the trace along x taken at y[j]
aux = np.transpose(z)

# Smooth the trace (Savitzky-Golay, window 17, order 2) and take the level
# halfway between its minimum and its maximum. Earlier attempts based on the
# largest |gradient| of the raw or smoothed trace were dropped in favour of
# this half-height criterion.
tt = savgol_filter(aux[n], 17, 2)
half_level = 0.5*(tt.max() - tt.min()) + tt.min()

# np.argmax / np.argmin each return one scalar index, so the comparison is
# well defined even when the extremum occurs at several samples (comparing
# the tuples returned by np.where raises in that case).
if np.argmax(tt) < np.argmin(tt):
    # maximum comes first: look for the first sample below half height
    pos = np.where(tt < half_level)
else:
    # minimum comes first: look for the first sample above half height
    pos = np.where(tt > half_level)

pos = pos[0][0]

# The style sheet only needs to be applied once for the three figures.
plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

cm = 1/2.54  # centimetres to inches

# Figure 1: 2D map with the detected point of trace n
fig = plt.figure(figsize=(9*cm, 7.5*cm), dpi=600)
plt.pcolormesh(y, x, z, cmap=my_cmap)
plt.colorbar()
plt.scatter(y[n], x[pos])

# Figure 2: raw trace, smoothed trace and the detected point
fig = plt.figure(figsize=(9*cm, 7.5*cm), dpi=600)
plt.plot(x, aux[n])
plt.plot(x, tt)
plt.scatter(x[pos], aux[n][pos])

# Figure 3: derivative of the smoothed trace and the detected point
fig = plt.figure(figsize=(9*cm, 7.5*cm), dpi=600)
tt_gradient = np.gradient(tt)
plt.plot(x, tt_gradient)
plt.scatter(x[pos], tt_gradient[pos])

#%%
# Half-height crossing index of every trace in `aux` (one trace per y value).
pos = []
for trace in aux:
    # Savitzky-Golay smoothing, window 17, order 2
    tt = savgol_filter(trace, 17, 2)
    half_level = 0.5*(tt.max() - tt.min()) + tt.min()

    # np.argmax / np.argmin return scalar indices, so the comparison stays
    # well defined when an extremum occurs at several samples (comparing the
    # tuples returned by np.where raises in that case).
    if np.argmax(tt) < np.argmin(tt):
        pos_aux = np.where(tt < half_level)
    else:
        pos_aux = np.where(tt > half_level)

    # keep the first sample that satisfies the condition
    pos.append(pos_aux[0][0])

plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

cm = 1/2.54  # centimetres to inches
fig = plt.figure(figsize=(9*cm, 7.5*cm), dpi=600)

plt.pcolormesh(y, x, z, cmap=my_cmap)
# x value of the crossing for each y, drawn on top of the map
xxx = x[pos].flatten()
plt.plot(y, xxx, color='black')

#%%

from scipy.optimize import curve_fit

plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

cm = 1/2.54  # centimetres to inches
fig = plt.figure(dpi=600, figsize=(9*cm, 7.5*cm))

# Index windows [start, stop) along y used for the fit, one row per dataset
# index; only indices 0-2 have an entry.
rere = [
    [205, 315],
    [205, 315],
    [100, 170],
]

a, b = rere[number]

# crossing position versus y inside the selected window
plt.plot(y[a:b], xxx[a:b])

e_charge = 1.60217663e-19 #C
k_b = 1.38064*1e-23 #J/K
T_e = 150*1e-3 #K
alpha = 0.2*e_charge
beta_e = 1/(k_b*T_e)
g = 2
mu_B = 9.2740100783*1e-24 #J/T
k = g*mu_B*beta_e

def VS(B, E_ST, A, C):
    """Model curve used for the fit: C + A*ln(ratio(B, E_ST)).

    With kB = k*B and bE = beta_e*E_ST*1e-6*e_charge the ratio is

        exp(kB/2 + bE) * (exp(kB) + 1)
        ---------------------------------------------
        exp(kB) + exp(2*kB) + exp(kB + bE) + 1

    Parameters:
        B: field value(s), scalar or array; `mu_B` is given in J/T.
        E_ST: energy in micro-eV (the factor 1e-6*e_charge converts it to J).
        A: prefactor of the logarithm.
        C: constant offset.

    The logarithm is evaluated in log space with np.logaddexp. Forming the
    exponentials directly overflows to inf/inf = nan once kB or bE gets large,
    which can happen while curve_fit explores the parameter space.
    """
    zeeman = k*B
    splitting = beta_e*E_ST*1e-6*e_charge

    # ln(numerator) = kB/2 + bE + ln(exp(kB) + 1)
    log_numerator = 0.5*zeeman + splitting + np.logaddexp(zeeman, 0.0)
    # ln(denominator) = ln(exp(kB) + exp(2 kB) + exp(kB + bE) + exp(0))
    log_denominator = np.logaddexp(np.logaddexp(zeeman, 2*zeeman),
                                   np.logaddexp(zeeman + splitting, 0.0))

    return C + A*(log_numerator - log_denominator)

# Fit VS to the crossing positions inside the window [a:b], starting from
# the guess E_ST = 150, A = 0.09, C = 827.0.
initial_guess = (150, 0.09, 827.0)
popt, covm = curve_fit(VS, y[a:b], xxx[a:b], p0=initial_guess)

# fitted curve on top of the data
plt.plot(y[a:b], VS(y[a:b], *popt))

print('Valley Splitting = ', format(popt[0], '.0f'), '(mu eV)')

#%%
plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

# layout and font settings, applied after the style sheet
plt.rcParams.update({
    'figure.autolayout': False,
    'figure.constrained_layout.use': True,
    'pdf.fonttype': 42,
})

cm = 1/2.54  # centimetres to inches
fig = plt.figure(dpi=600, figsize=(10.5/2*cm, 7.5/2*cm))

# index window along y over which the detected crossing and the fit are drawn
a, b = 100, 310

# 2D map with the detected crossing (thin line) and the fitted curve on top
plt.pcolormesh(y, x, z, cmap=my_cmap)
plt.colorbar()
plt.plot(y[a:b], xxx[a:b], color='black', lw=0.5)
plt.plot(y[a:b], VS(y[a:b], *popt), color='black')

plt.xlim(0, 3)
plt.xlabel('B (T)')
plt.ylabel('V (mV)')

#%%

# Bare 2D map of the gridded data, without any overlay.
plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

cm = 1/2.54  # centimetres to inches
fig = plt.figure(dpi=600, figsize=(9*cm, 7.5*cm))

plt.pcolormesh(y, x, z, cmap=my_cmap)

#%% Plot the conductance
import numpy as np

# Savitzky-Golay smoothing (window 11, order 2) followed by the derivative
# along axis 0 of z.
# NOTE(review): savgol_filter is called without `axis`, so it smooths along
# its default axis while the gradient is taken along axis 0 - confirm that
# smoothing and differentiating along different axes is intended.
aux = np.gradient(savgol_filter(z, 11, 2), axis=0)

plt.rcParams.update({
    'figure.autolayout': False,
    'figure.constrained_layout.use': True,
    'pdf.fonttype': 42,
})
plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

cm = 1/2.54  # centimetres to inches
fig = plt.figure(dpi=600, figsize=(9*cm, 7.5*cm))

# colour scale clipped symmetrically around zero
plt.pcolormesh(y, x, aux, cmap=my_cmap, vmin=-0.001, vmax=0.001)
plt.colorbar()

#%%

from scipy.signal import savgol_filter, find_peaks

plt.rcParams['figure.autolayout'] = False
plt.rcParams['figure.constrained_layout.use'] = True
plt.rcParams['pdf.fonttype'] = 42
plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

cm = 1/2.54  # centimetres to inches

# Figure 1: derivative of trace 50 with its largest-magnitude point marked.
fig = plt.figure(figsize=(9*cm, 7.5*cm), dpi=600)

# derivative of the raw data along axis 0, transposed so that aux[j] is the
# derivative trace along x taken at y[j]
aux = np.gradient(z, axis=0)
aux = np.transpose(aux)
plt.plot(x, aux[50])
# np.argmax returns a single index (the first one if the maximum is tied),
# so x[pos] is always one point
pos = np.argmax(np.abs(aux[50]))
plt.scatter(x[pos], aux[50][pos])

# Figure 2: position of the largest |derivative| for the first 50 traces.
# The rcParams and the style sheet set above are still in effect.
fig = plt.figure(figsize=(9*cm, 7.5*cm), dpi=600)

for i in range(0, 50):
    pos = np.argmax(np.abs(aux[i]))
    plt.scatter(x[pos], y[i])


#%% valley splitting function

e_charge = 1.60217663e-19 #C
k_b = 1.38064*1e-23 #J/K
T_e = 150*1e-3 #K
alpha = 0.2*e_charge
beta_e = 1/(k_b*T_e)
g = 2
mu_B = 9.2740100783*1e-24 #J/T
k = g*mu_B*beta_e


def _log_ratio(B, E_ST):
    """Return ln(ratio(B, E_ST)), the logarithm shared by Vp and VS.

    With kB = k*B and bE = beta_e*E_ST*1e-6*e_charge the ratio is

        exp(kB/2 + bE) * (exp(kB) + 1)
        ---------------------------------------------
        exp(kB) + exp(2*kB) + exp(kB + bE) + 1

    It is evaluated in log space with np.logaddexp, because forming the
    exponentials directly overflows to inf/inf = nan for large kB or bE.
    """
    zeeman = k*B
    splitting = beta_e*E_ST*1e-6*e_charge

    # ln(numerator) = kB/2 + bE + ln(exp(kB) + 1)
    log_numerator = 0.5*zeeman + splitting + np.logaddexp(zeeman, 0.0)
    # ln(denominator) = ln(exp(kB) + exp(2 kB) + exp(kB + bE) + exp(0))
    log_denominator = np.logaddexp(np.logaddexp(zeeman, 2*zeeman),
                                   np.logaddexp(zeeman + splitting, 0.0))
    return log_numerator - log_denominator


# Note: E_ST is given in micro-eV.
def Vp(B, E_ST):
    """Return ln(ratio(B, E_ST)) / (alpha*beta_e).

    B is the field (scalar or array; `mu_B` is given in J/T) and E_ST an
    energy in micro-eV.
    """
    return 1/(alpha*beta_e)*_log_ratio(B, E_ST)


def VS(B, E_ST, A, C=0.0):
    """Return C + A*ln(ratio(B, E_ST)).

    `C` defaults to 0, so the two-amplitude call VS(B, E_ST, A) behaves as
    before, while VS(B, E_ST, A, C) matches the four-argument form that the
    fitting code earlier in this file calls.
    """
    return C + A*_log_ratio(B, E_ST)


plt.rcParams.update({
    'figure.autolayout': False,
    'figure.constrained_layout.use': True,
    'pdf.fonttype': 42,
})
plt.style.use(r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle')

cm = 1/2.54  # centimetres to inches
fig = plt.figure(dpi=600, figsize=(9*cm, 7.5*cm))

# Vp for E_ST = 150 (micro-eV) on 100 field values between 0 and 4
xx = np.linspace(0, 4, 100)
plt.plot(xx, Vp(xx, 150))
