from modules import semiclassics
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from scipy import optimize, interpolate, signal, integrate, special

# Load the simulated collimation data and select quantization=True
ds = xr.load_dataset('./data/semiclassical_collimation.nc').sel(quantization=True)

# Smooth the current along B with a 5-point centred rolling mean
smoothed_current = ds.Is.rolling(B=5, center=True).mean()
# Keep only the points with B >= 0, then sum over valleys
positive_branch = smoothed_current.where(ds.B >= 0).sum("valley")
# Index along B where the summed current is largest
peak_index = positive_branch.argmax("B")
# `sel` is a selector -- inspect the dataset to see all values and choose different ones
peak_field = ds.B[peak_index].sel(misalignment=0, offset=0, method="nearest")

# 1e3 converts to millitesla; 2 is the distance from one peak to the other
peak_spacing = 1e3 * 2 * peak_field
# x-axis is back gate, y is field, color is alpha (amount of trigonal warping)
peak_spacing.plot.scatter(x="experiment", y="B", hue='alpha')
plt.grid()
plt.title("")
plt.xlabel(r"$V_{bg}~[V]$")
plt.ylabel(r"$\Delta B~[mT]$")
plt.show()

# +
# Plot valley-resolved current
# Choose parameters
voltage = 3
offset = 0
misalignment = 0

# Selection shared by both valleys (alpha=1, i.e. with trigonal warping).
# To smooth first, select from `ds.rolling(B=5, center=True).mean()` instead of `ds`.
selection = dict(
    offset=offset,
    experiment=voltage,
    alpha=1,
    misalignment=misalignment,
    method="nearest",
)
I1 = ds.Is.sel(valley=1, **selection)  # one valley
I2 = ds.Is.sel(valley=-1, **selection)  # the other valley
Itot = I1 + I2

# Normalize both valleys by the peak (over B) of the total current.
# This is done out of place on purpose: an in-place `/=` on a `.sel` result
# can write through to `ds.Is` when the selection is a view of its data.
Itot_norm = Itot.max(["B"])
I1 = I1 / Itot_norm
I2 = I2 / Itot_norm

I1.plot(c="r")
I2.plot(c="b")
# Plot both valleys together
(I1 + I2).plot(c="k")
plt.xlim(-0.1, 0.1)
plt.ylabel(r"$R_{nl}~[a.u.]$")
plt.title("")
plt.show()

# +
# Plot valley-resolved current
# Same selection as the previous cell, but with alpha=0 (no trigonal warping).
# `offset` and `misalignment` are the values chosen in the previous cell.
voltage = 3

no_warping = dict(
    offset=offset,
    experiment=voltage,
    alpha=0,
    misalignment=misalignment,
    method="nearest",
)
I1 = ds.Is.sel(valley=1, **no_warping)
I2 = ds.Is.sel(valley=-1, **no_warping)

# Every curve is divided by the peak (over B) of the summed current
Itot_norm = (I1 + I2).max(['B'])
for current, line_style in (
    (I1, dict(c="r")),
    (I2, dict(c="b", ls="--")),
    (I1 + I2, dict(c="k")),
):
    (current / Itot_norm).plot(**line_style)
plt.xlim(-0.1, 0.1)
plt.ylabel(r"$R_{nl}~[a.u.]$")
plt.title("")
plt.show()

# +
# Plot a lot of data to exemplify how to visualize
voltage = 2

# Smooth Is once (5-point centred rolling mean along B) and reuse it for both
# valleys, instead of recomputing the rolling mean of the whole dataset twice.
smoothed_Is = ds.Is.rolling(B=5, center=True).mean()
I1 = smoothed_Is.sel(experiment=voltage, valley=1, method="nearest")
I2 = smoothed_Is.sel(experiment=voltage, valley=-1, method="nearest")

# Total current, normalized by its peak over B for every remaining parameter
summed = I1 + I2
(summed / summed.max(['B'])).plot.scatter(
    x="B", y="Is", ec=None, hue="alpha", row="offset", col="misalignment"
)
plt.xlim(-0.1, 0.1)
plt.show()
# -

# Same for other parameters
# Smooth Is once (5-point centred rolling mean along B) and reuse it for both
# valleys, instead of recomputing the rolling mean of the whole dataset twice.
smoothed_Is = ds.Is.rolling(B=5, center=True).mean()
I1 = smoothed_Is.sel(offset=0, valley=1, method="nearest")
I2 = smoothed_Is.sel(offset=0, valley=-1, method="nearest")

# Total current, normalized by its peak over B for every remaining parameter
summed = I1 + I2
(summed / summed.max(['B'])).plot.scatter(
    x="B", y="Is", ec=None, hue="alpha", row="experiment", col="misalignment"
)
plt.xlim(-0.1, 0.1)
plt.ylabel(r"$R_{nl} / \max{(R_{nl})}~[\Omega]$")
plt.title("")
plt.show()

# NOTE: this rebinds `ds`; everything below, including `compute_dIdphi`,
# reads the Fermi-surface dataset rather than the collimation dataset.
ds = xr.load_dataset('./data/experimental_fermi.nc')

# +
# Plot Fermi surface for both valleys

voltage = 0

fig = plt.figure()
ax = fig.add_subplot(111, polar=True)
# Draw both valleys on the polar axes created above. The axes is passed
# explicitly so the result does not depend on which axes is "current".
for valley_sign, colour in ((1, 'r'), (-1, 'b')):
    ds.sel(
        valley=valley_sign,
        experiment=voltage,
        misalignment=100,
        alpha=1.2,
        method="nearest",
    ).kF.plot(ax=ax, c=colour)
ax.set_title('')
plt.show()


# -
# dIdphi calculator
def compute_dIdphi(valley, voltage, alpha):
    """Compute a smoothed angular injection profile for one valley.

    Reads the module-level dataset ``ds`` and selects the entry nearest to the
    requested parameters at ``misalignment=0``. The heavy lifting is delegated
    to the project module ``semiclassics``; see that module for the physics.

    Parameters
    ----------
    valley
        Value of the ``valley`` coordinate to select (the script uses 1 and -1).
    voltage
        Value of the ``experiment`` coordinate to select.
    alpha
        Value of the ``alpha`` coordinate to select.

    Returns
    -------
    injection_angle
        Injection angles restricted to ``[0, pi]``.
    profile
        ``cos(injection_angle - pi / 2)`` multiplied by the smoothed,
        peak-normalized, occupation-weighted injection probability evaluated
        at each of those angles.
    """
    from scipy import ndimage

    dataset = ds.sel(valley=valley, experiment=voltage,
                     # offset=0,
                     misalignment=0, alpha=alpha, method='nearest')
    # The sign of B sets the traversal direction of the stored arrays
    # (step +1 keeps the stored order).
    B = 1
    step = int(np.sign(B))
    kF = dataset.kF.data[::step]
    polar_angles = dataset.phis.data[::step]
    phic = 0
    k = semiclassics.kx_ky(k_F=kF, phi=polar_angles, phic=phic)
    occupation, _ = semiclassics.quantization(
        k=k,
        phic=phic,
        U=dataset.U.data,
        valley=dataset.valley.data,
        mu=dataset.mu.data,
        Ec=dataset.Ec.data,
        Ecmax=dataset.Ecmax.data,
        W=dataset.attrs["Wc"],
        alpha=dataset.alpha.data,
    )
    # Diagnostic output of the occupation array
    print(occupation)
    _, dr = semiclassics.k_to_r(
        k_F=kF,
        phi=polar_angles,
        phic=phic,
        B=B,
    )

    dIdphi, injection_angle = semiclassics.calculate_injection_prob(dr, cos=False)

    # Keep only injection angles inside [0, pi]
    in_range = (injection_angle >= 0) & (injection_angle <= np.pi)

    # The last occupation entry is dropped so that it lines up with injection_angle
    occupation = occupation[:-1][in_range]
    injection_angle = injection_angle[in_range]
    dIdphi = dIdphi[in_range] * occupation

    # Occupation-weighted histogram of the injection angles
    nbins = 300
    counts, edges = np.histogram(injection_angle, bins=nbins, weights=dIdphi)

    # Gaussian smoothing; sigma is 5/180 of the bin count, i.e. 5 degrees when
    # the histogram spans 180 degrees.
    # NOTE(review): the histogram spans [min, max] of the filtered angles,
    # which may be narrower than [0, pi] -- confirm the intended width.
    smoothed = ndimage.gaussian_filter(counts, sigma=nbins * 5 / 180)
    normalized = smoothed / np.max(smoothed)

    # Histogram values belong to the bin centres. Interpolating on the right
    # edges would shift the profile by half a bin and zero-fill every angle in
    # the first bin. Angles within half a bin of either end take the value of
    # the nearest bin.
    centres = 0.5 * (edges[:-1] + edges[1:])
    smearing = interpolate.interp1d(
        centres,
        normalized,
        fill_value=(normalized[0], normalized[-1]),
        bounds_error=False,
    )

    return injection_angle, np.cos(injection_angle - np.pi / 2) * smearing(injection_angle)


# +
# Calculate dIdphi with and without trigonal warping

voltage = 0
alpha = 1.5

# Both valleys at the chosen alpha, plus one valley at alpha=0 as reference
injection1, dIdphi1 = compute_dIdphi(valley=1, voltage=voltage, alpha=alpha)
injection2, dIdphi2 = compute_dIdphi(valley=-1, voltage=voltage, alpha=alpha)
injection3, dIdphi3 = compute_dIdphi(valley=1, voltage=voltage, alpha=0)

fig = plt.figure()
ax = fig.add_subplot(111, polar=True)
# Each curve is normalized to its own maximum
ax.plot(injection1, dIdphi1 / np.max(dIdphi1), c='r')
ax.plot(injection2, dIdphi2 / np.max(dIdphi2), c='b')
ax.plot(injection3, dIdphi3 / np.max(dIdphi3), c='k', ls='--')
# Show only the upper half of the polar plot
ax.set_xlim(0, np.pi)
ax.set_xlabel(r'$\phi$')
ax.set_ylabel(r'$dI/d\phi$')
plt.show()
# -

# Bare expression: presumably here to display the `experiment` coordinate when
# run as a notebook cell; it has no visible effect when run as a plain script.
ds.experiment


