# -*- coding: utf-8 -*-
"""
Created on Thu Aug  5 13:07:21 2021

@author: ddegliesposti
"""

import glob
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from scipy.signal import savgol_filter, find_peaks
import matplotlib as mpl

# Figure defaults, applied in one call.
plt.rcParams.update({
    'figure.dpi': 230,  # helpful for high-dpi displays such as 4K screens on Windows
    'figure.autolayout': True,
    'font.size': 12,
    'text.usetex': False,
})

def function(x, a, c):
    """Power law ``a * x**c``.

    Used in the analysis cell as the pure 1/f**alpha noise model, where
    alpha = -c.
    """
    return a * (x ** c)

def lin(x, a, c):
    """Straight line with slope ``a`` and intercept ``c``."""
    return c + a * x

def power_lorantian(x, a, c1, B, x0):
    """Power law ``a * x**c1`` plus a Lorentzian ``B / (1 + (x/x0)**2)``.

    The Lorentzian term equals B at x = 0 and drops to B/2 at x = x0.
    """
    power_law = a * x ** c1
    lorentzian = B / (1 + (x / x0) ** 2)
    return power_law + lorentzian

#defining the PSD function

def PSD(array, time_step):
    '''
    Power spectral density of a signal sampled at a constant interval.

    array: sequence of samples, one every `time_step`
    time_step: scalar, spacing between consecutive samples
    Returns:
    f_axis: array of the strictly positive frequencies, in increasing order.
            The DC bin is dropped, and for an even number of samples the
            Nyquist bin is dropped too (numpy labels it as negative).
    psd: array, time_step**2 * |FFT|**2 / (time_step * len(array)) at the
         frequencies in f_axis; no one-sided factor of 2 is applied.
    '''
    n_samples = len(array)
    # fftfreq stores bins 1..(n-1)//2 as the positive frequencies.
    keep = slice(1, (n_samples - 1) // 2 + 1)

    freqs = np.fft.fftfreq(n_samples, time_step)
    period = time_step * n_samples
    spectrum = time_step * time_step * (np.abs(np.fft.fft(array))) ** 2 / period
    return freqs[keep], spectrum[keep]


#%% Loading all the files

from core_tools.data.ds.data_set import load_by_id, load_by_uuid
import numpy as np
import matplotlib.pyplot as plt

# Conversion factor used to turn the measured signal into current
# (Isd = ds.m1() / Amp in the analysis cell). Presumably an amplifier gain
# in V/A -- TODO confirm.
Amp = 1e8

# Figure dpi, re-applied in this cell; helpful for high-dpi displays such as
# 4K screens on Windows.
plt.rcParams['figure.dpi'] = 230


# Per-peak calibration data. Entry i of slope, l_arm and uuid_slope refers to
# the same peak; the analysis cell indexes all three with `peak`.

# slope: divisor used to turn the current PSD into an epsilon PSD
# (PSD_epsilon = l_arm**2 * psd / slope**2). Presumably a dI/dV in A/V --
# TODO confirm units.
slope = [2.948166077738232e-07,
         4.206149236073817e-07,
         4.867572369762116e-07,
         8.209239689280597e-07,
         9.9001291095259e-07,
         1.1909849827375159e-06,
         1.5493683272635784e-06,
         1.579417394406175e-06,
         1.9833832602393895e-06,
         2.0668306738091833e-06
          ]

# Lever arm of each peak.
l_arm = [0.348757487,   # in eV/V
         0.323528773,
         0.316750456,
         0.276639482,
         0.27783133,
         0.233213263,
         0.241645396,
         0.215836589,
         0.222497763,
         0.22158971
         ]

# uuids of the time-trace datasets, loaded with load_hdf5_uuid in the
# analysis cell.
uuid_slope = [1675363386167255110,
              1675367768323255110,
              1675372198705255110,
              1675376755117255110,
              1675381186103255110,
              1675385624022255110,
              1675390214482255110,
              1675394922686255110,
              1675399394208255110,
              1675404018132255110]

#%% data new
from core_tools.data.ds.ds_hdf5 import save_hdf5_uuid, load_hdf5_uuid
# Index of the peak to analyse; selects the entries of slope, l_arm,
# uuid_slope and control_fit.
peak = 9

# Fit model per peak: 0 -> pure power law a*f**c; anything else ->
# power law + Lorentzian.
control_fit = [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]

DATA_DIR = r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper\QuantumDots\1_Dev1\TimeTraces'
STYLE_PATH = r'W:\staff-groups\tnw\ns\qt\ScappucciLab\0_Group members\Davide\6nm_QW_paper/Style.mplstyle'

ds = load_hdf5_uuid(int(uuid_slope[peak]), DATA_DIR)

time = ds.m1.x()
# Divide the measured signal by Amp to get the current in A. Presumably Amp
# is the amplifier's V/A gain -- TODO confirm.
Isd = ds.m1() / Amp
data = Isd

num_segments = 10
time_step = time[1] - time[0]  # sampling interval in seconds

# Split the trace into num_segments equal, non-overlapping chunks. Trailing
# samples that do not fill a whole chunk are discarded.
seg_len = len(data) // num_segments
segment = [data[i * seg_len:(i + 1) * seg_len] for i in range(num_segments)]

new_x = time[0:seg_len]  # time axis of a single segment (not used below)

# PSD of each segment. All segments have the same length, so every row of
# freq_single is the same frequency axis.
spectra = [PSD(seg, time_step) for seg in segment]
freq_single = np.array([f for f, _ in spectra])
psd_single = np.array([p for _, p in spectra])
frequencies = freq_single[0]

# Average the segment spectra.
psd_avg = psd_single.mean(axis=0)

# Convert the current PSD into a PSD of epsilon (eV^2/Hz, see the y label)
# using this peak's lever arm (eV/V) and slope.
PSD_epsilon = l_arm[peak] ** 2 * psd_avg / slope[peak] ** 2

plt.style.use(STYLE_PATH)
plt.rcParams['pdf.fonttype'] = 42  # embed TrueType (Type 42) fonts in saved PDFs

cm = 1 / 2.54  # inches per centimetre, for figsize
# my_cmap is only created in the colour-palette cell further down this file,
# so referencing it here failed with NameError when the file ran top to
# bottom. This hex value is the palette's first colour, my_cmap(0.0),
# i.e. rgb(29, 113, 184).
point_color = '#1d71b8'
plot_lo, plot_hi = 0, 2600
fig = plt.figure(figsize=(9 / 2 * cm, 8.5 / 2 * cm), dpi=600)
plt.scatter(frequencies[plot_lo:plot_hi], PSD_epsilon[plot_lo:plot_hi], color=point_color, s=0.2)

plt.xscale('log')
plt.yscale('log')
plt.xlabel('f (Hz)')
plt.ylabel(r'S$_{\epsilon}$ (eV$^2$/Hz)')
plt.xlim(1 / 45, 45)
plt.ylim(1e-15, 1e-10)

# Presumably alternative per-peak fit-window indices; not used below.
ai = [5, 5, 5, 5, 30, 7, 15, 20, 20]
bi = [2500, 2500, 2500, 2500, 2500, 2500, 2200, 2500, 2500]

use_lorentzian = control_fit[peak] != 0
if use_lorentzian:
    # 1/f power law + Lorentzian; parameters (a, c1, B, x0).
    model = power_lorantian
    fit_lo, fit_hi = 3, 2200
    fit_kwargs = dict(p0=(1e-14, -1.0, 0.4e-12, 5), maxfev=120000)
else:
    # Pure 1/f power law; parameters (a, c).
    model = function
    fit_lo, fit_hi = 30, 2600
    fit_kwargs = dict(p0=(1e-14, -1.0))

popt, covm = curve_fit(model, frequencies[fit_lo:fit_hi], PSD_epsilon[fit_lo:fit_hi], **fit_kwargs)
plt.plot(frequencies, model(frequencies, *popt), color='black')

alpha = -popt[1]  # S ~ 1/f**alpha, so alpha is minus the fitted exponent
# Absolute 1-sigma uncertainty of alpha: the standard error of the exponent,
# taken from the diagonal of the fit covariance.
alpha_err = np.sqrt(covm[1, 1])
# sqrt(S_epsilon) at 1 Hz, scaled from eV to micro-eV per sqrt(Hz).
charge_noise = np.sqrt(model(1, *popt)) * 1e6

print('Alpha = ', "{:.2F}".format(alpha), '+-', "{:.2F}".format(alpha_err))
print('Charge noise = ', "{:.3F}".format(charge_noise))
if use_lorentzian:
    print('f0 = ', "{:.3F}".format(popt[3]))

# Raw strings keep the LaTeX backslashes intact; the value is in micro-eV.
string = (r'S$_{\epsilon}^{1/2}$ = ' + "{:.3F}".format(charge_noise)
          + r' ($\mu$eV/Hz$^{1/2}$)')
plt.text(0.04, 0.4e-14, string, fontsize=6)

if use_lorentzian:
    # Only the Lorentzian fits are saved; the file name records the fit results.
    string = ('peak_' + str(peak) + '_S_' + "{:.3F}".format(charge_noise)
              + '_alpha_' + "{:.2F}".format(alpha)
              + '_f0_' + "{:.3F}".format(popt[3]) + '.pdf')
    plt.savefig(fname=string)
    

#%% Define the color palette

import numpy as np

# hex (string) to rgb (tuple3)
def hex2rgb(hex):
    """Convert a '#rrggbb' (or 'rrggbb') string to an (r, g, b) tuple of ints."""
    hex_cleaned = hex.lstrip('#')
    return tuple(int(hex_cleaned[i:i+2], 16) for i in (0, 2, 4))

# rgb (tuple3) to hex (string)
def rgb2hex(rgb):
    """Convert three channel ints to a lowercase '#rrggbb' string.

    Channels are expected in 0-255. Each channel keeps only its last two hex
    digits, and single-digit values are zero-padded.
    """
    return '#' + ''.join([str('0' + hex(hh)[2:])[-2:] for hh in rgb])

# weighted mix of two colors in RGB space (takes and returns hex values)
def color_mixer(hex1, hex2, wt1=0.5):
    """Blend two hex colours channel by channel.

    wt1 is the weight of hex1 and hex2 gets 1 - wt1. Each mixed channel is
    truncated to an int. Returns a '#rrggbb' string.
    """
    rgb1 = hex2rgb(hex1)
    rgb2 = hex2rgb(hex2)
    return rgb2hex(tuple([int(wt1 * tup[0] + (1.0 - wt1) * tup[1]) for tup in zip(rgb1, rgb2)]))

# create full palette
def create_palette(start_color, mid_color, end_color, num_colors):
    """Build a list of num_colors hex colours running start -> mid -> end.

    Parameters
    ----------
    start_color, mid_color, end_color : str
        '#rrggbb' colours at the start, centre and end of the palette.
    num_colors : int
        Number of colours to return.

    Returns
    -------
    list of str
        Hex colours. mid_color itself is included when num_colors is odd.
    """
    # Blend weights descending from 1 (all first colour) to 0 (all second).
    steps = np.linspace(0, 1, num_colors)[::-1]

    # Two halves of num_colors and num_colors - 1 colours. The second half
    # skips its first entry so mid_color is not duplicated at the seam.
    # Use the parameters; the globals first_color/last_color may not exist.
    pt1 = [color_mixer(start_color, mid_color, wt) for wt in steps]
    pt2 = [color_mixer(mid_color, end_color, wt) for wt in steps[1:]]

    # That gives 2 * num_colors - 1 colours in total; taking every other one
    # leaves num_colors.
    return (pt1 + pt2)[::2]

# Anchor colours of the palette.
si_rgb = (29, 113, 184)  # This is the color of Si
ge_rgb = (190, 22, 34)   # This is the color of Ge
first_color = rgb2hex(si_rgb)
last_color = rgb2hex(ge_rgb)
mid_color = '#fefefe'    # near-white centre colour

# 100 hex colours blended first_color -> mid_color -> last_color.
result = create_palette(first_color, mid_color, last_color, num_colors=100)

from matplotlib.colors import ListedColormap, LinearSegmentedColormap

my_cmap = ListedColormap(result)
