# %%
import time as t
import numpy as np
import scipy.constants as c
import matplotlib.pyplot as plt
from functools import partial
import xarray as xr
import itertools as it
from tqdm import tqdm
# from modules import semiclassics
from modules import dispersion, semiclassics
import xarray as xr
from scipy import optimize, integrate, interpolate

from matplotlib import rc

# Global figure styling: Helvetica sans-serif text rendered through LaTeX,
# thin lines and 15pt base / legend font sizes.
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica"],
        "text.usetex": True,
        "lines.linewidth": 0.654,
        "font.size": 15,
        "legend.fontsize": 15,
    }
)

# %%
# Extra sweep axes attached to the loaded data as new coordinates. Insertion
# order matters: later cells derive the parameter grid from the coordinate
# order of ``ds``.
params = dict(
    B=np.linspace(-0.15, 0.15, num=250),  # presumably magnetic field; 250 points, endpoints included
    offset=np.array([0, 100e-9, 200e-9]),  # presumably metres (0/100/200 nm) — TODO confirm units
    quantization=[True, False],  # size-quantization on/off
)
ds = xr.load_dataset('./data/experimental_fermi.nc').assign_coords(params)

# %%
# Start a remote Dask cluster. ``extra_path`` is presumably added to the
# workers' import path so they can import the local ``modules`` package.
# The meaning of the leading 200 (likely the worker count) is not visible
# here — TODO confirm.
from dask_quantumtinkerer import Cluster

worker_extra_path = "~/Work/electron-focusing-blg/theoretical_calculations/"
cluster = Cluster(200, extra_path=worker_extra_path)
cluster.launch_cluster()

# %%
client = cluster.get_client()

# %%
client

# %%
# Build the flat parameter grid for the distributed calculation.
# The first coordinate of ``ds`` is skipped (it is not swept here). The
# remaining coordinates, in dataset order, must line up with the positional
# indices read in ``wrapped_current``.
values = [np.array(value, ndmin=1) for value in list(ds.coords.values())[1:]]
keys = list(ds.coords.keys())[1:]

# ``dtype=object`` keeps every coordinate's native type (bool, float, str, ...)
# instead of promoting all columns to one common dtype. Without it, the
# ``quantization`` flags silently become floats, and a single string coordinate
# would turn every column (including B and offset) into strings.
args = np.array(list(it.product(*values)), dtype=object)
shapes = [len(value) for value in values]

# %%
keys


# %%
def wrapped_current(n, args, fermi_ds):
    """Run ``semiclassics.collimation_calc`` for one point of the parameter grid.

    ``modules`` is imported inside the function, presumably so the import
    happens on the Dask worker executing the task.

    Parameters
    ----------
    n : int
        Row index into ``args``.
    args : array-like
        Parameter grid. Row ``n`` holds, in order: experiment, alpha,
        misalignment, valley, B, offset and the size-quantization flag.
    fermi_ds : xarray.Dataset
        Data from which the slice for this parameter point is selected by
        nearest match on experiment, alpha, misalignment and valley.

    Returns
    -------
    The value returned by ``semiclassics.collimation_calc``.
    """
    from modules import dispersion, semiclassics  # noqa: F401

    row = args[n]
    # Nearest-neighbour lookup, since grid values need not exactly equal
    # the dataset's coordinate labels.
    indexers = {
        "experiment": row[0],
        "alpha": row[1],
        "misalignment": row[2],
        "valley": row[3],
    }
    subset = fermi_ds.sel(indexers, method="nearest")
    return semiclassics.collimation_calc(
        dataset=subset,
        B=row[4],
        offset=row[5],
        size_quantization=row[6],
    )


# %%
# Ship the dataset and the parameter grid to the cluster once, so each task
# only carries a reference to the remote copies.
ds_scattered = client.scatter(ds)
args_scattered = client.scatter(args)

# Run calculation: one task per row of ``args``, then block until every
# result has been collected (results keep the submission order).
n_tasks = len(args)
result_ungathered = [
    client.submit(wrapped_current, task_index, args_scattered, ds_scattered)
    for task_index in range(n_tasks)
]
result = client.gather(result_ungathered)

# %%
# Fold the flat result list back into the N-dimensional grid (same ordering
# as ``it.product`` over ``values``) and attach it to the dataset as ``Is``.
Is = np.asarray(result).reshape(shapes)
_ds = ds.assign(Is=(keys, Is))

# %%
# Plot the valley-summed current versus B, normalized to its maximum over B,
# at zero offset, alpha nearest to 1.2 and with size quantization enabled.
# Only ``Is`` is reduced over valley; summing the whole dataset would also
# reduce every other data variable for no benefit.
I = _ds.Is.sum("valley").sel(offset=0, alpha=1.2, method="nearest").sel(quantization=True)
(I / I.max(["B"])).plot(row="experiment", hue="alpha", col="misalignment")
# plt.yscale('log')
plt.xlim(-0.1, 0.1)
plt.show()

# %%
# Plot the valley-summed current versus B, normalized to its maximum over B,
# at zero offset, alpha nearest to 1 and with size quantization disabled.
# Only ``Is`` is reduced over valley; summing the whole dataset would also
# reduce every other data variable for no benefit.
I = _ds.Is.sum("valley").sel(offset=0, alpha=1, method="nearest").sel(quantization=False)
(I / I.max(["B"])).plot(row="experiment", hue="alpha", col="misalignment")
# plt.yscale('log')
plt.xlim(-0.1, 0.1)
plt.show()

# %%
# Persist the dataset, now including the computed currents ``Is``.
_ds.to_netcdf(path='./data/semiclassical_collimation.nc')

# %%
# !ssh hpc05 -C "killall dask-quantumtinkerer-server"
