import xarray as xr
import numpy as np
import itertools as it
from modules import dispersion, semiclassics

# Experimental table; columns 0, 2, 3 and 5 are the ones read below.
data = np.loadtxt('./data/ElectricFieldsAndDensities.txt')

# Element-wise wrapper so compute_mu_and_U can be called on whole arrays.
vectorized_fun = np.vectorize(dispersion.compute_mu_and_U)

# Per-experiment inputs (all along the 'experiment' dimension) together with
# the coordinate grid that is scanned further down. n_max is filled with
# zeros, one entry per row of the table.
experiment_ds = xr.Dataset(
    data_vars={
        'D': ('experiment', data[:, 3]),
        'D_max': ('experiment', data[:, 2]),
        'n_exp': ('experiment', data[:, 5]),
        'n_max': ('experiment', np.zeros(len(data))),
    },
    coords=dict(
        phis=np.linspace(0, 2 * np.pi, 40000),
        experiment=data[:, 0],
        # 0, 5 and 10 degrees converted to radians, with a negative sign.
        misalignment=-np.array([0, 5, 10]) * np.pi / 180,
        valley=[-1, 1],
        alpha=np.linspace(0, 1.2, 7),
    ),
    attrs=dict(L=4e-6, Wc=50e-9),
)

# Evaluate compute_mu_and_U element-wise: once at (n_exp, D) and once at
# (n_max, D_max). The first output of the second call is discarded.
mus, Us, E_c = vectorized_fun(experiment_ds.n_exp, experiment_ds.D)
_, U_max, Ec_max = vectorized_fun(experiment_ds.n_max, experiment_ds.D_max)

# Attach the results along the 'experiment' dimension. np.vectorize returns
# plain ndarrays, whose `.data` attribute is a raw buffer object rather than
# the array itself, so the arrays are passed through np.asarray instead.
ds = experiment_ds.assign(
    {
        'mu': ('experiment', np.asarray(mus)),
        'U': ('experiment', np.asarray(Us)),
        'Umax': ('experiment', np.asarray(U_max)),
        'Ec': ('experiment', np.asarray(E_c)),
        'Ecmax': ('experiment', np.asarray(Ec_max)),
    }
)

# Reshape params
# Scanned coordinates, listed explicitly in the positional order that
# wrapped_fermi reads from each row of `args` (arg[0] .. arg[3]). 'phis' is
# left out because it is handed to every call as a whole array.
keys = ['experiment', 'misalignment', 'valley', 'alpha']
values = [ds.coords[key] for key in keys]
# Cartesian product over the underlying NumPy values; iterating the
# coordinate objects directly would create one 0-d DataArray per element.
args = np.array(list(it.product(*(coord.values for coord in values))))
# Length of each scanned axis, used later to unfold the flat result list.
shapes = [len(coord) for coord in values]

keys

# +
# Create and launch the Dask cluster used for the parameter scan below.
from dask_quantumtinkerer import Cluster

# 110 is the first positional argument, presumably the number of workers
# (TODO confirm against dask_quantumtinkerer). extra_path presumably makes the
# project-local `modules` package importable on the workers; verify.
cluster = Cluster(
    110, extra_path="~/Work/electron-focusing-blg/theoretical_calculations/"
)
cluster.launch_cluster()
# -

# Client handle used below to scatter the dataset and submit tasks.
client = cluster.get_client()

# Bare expression: has no effect in a plain script run.
client

# Bare expression: has no effect in a plain script run.
keys


# +
# Wrapper for current calculation
def wrapped_fermi(arg, ds):
    """Evaluate ``semiclassics.k_F`` for one parameter combination.

    Parameters
    ----------
    arg : sequence
        Indexable object whose positions 0-3 hold, in this order, the
        ``experiment``, ``misalignment``, ``valley`` and ``alpha`` values.
    ds : xarray.Dataset
        Dataset providing the ``phis`` coordinate and the ``mu`` and ``U``
        variables.

    Returns
    -------
    The value returned by ``semiclassics.k_F`` for the full ``phis`` array,
    evaluated at the grid point of ``ds`` nearest to ``arg``.
    """
    # Function-level import; presumably so the project package is resolved
    # on the worker that executes the task (verify). `dispersion` is imported
    # but not used in this function.
    from modules import dispersion, semiclassics

    # Nearest-neighbour lookup of the requested grid point.
    selection = {
        'experiment': arg[0],
        'misalignment': arg[1],
        'valley': arg[2],
        'alpha': arg[3],
    }
    point = ds.sel(**selection, method='nearest')

    # The angles come from the full dataset; everything else comes from the
    # selected grid point.
    angles = ds.phis.data
    return semiclassics.k_F(
        phi=angles,
        phic=point.misalignment.data,
        valley=point.valley.data,
        mu=point.mu.data,
        U=point.U.data,
        alpha=point.alpha.data,
    )


# Send the dataset to the cluster once; every task receives the same handle.
ds_scattered = client.scatter(ds)
# One task per row of `args`, i.e. per parameter combination.
result_ungathered = [client.submit(wrapped_fermi, arg, ds_scattered) for arg in args]
# Collect the task results locally; the reshape further down relies on them
# being in the same order as `args`.
result = client.gather(result_ungathered)
# -

# Stack the gathered results (one entry per row of `args`), then unfold the
# flat combination axis into one axis per scanned parameter. The trailing
# axis is whatever each task returned; it must match 'phis' for the
# assignment below to succeed.
kF = np.array(result)
kF = kF.reshape(*shapes, kF.shape[-1])

# Store kF with the angle axis first. The dimension labels are built from
# `keys` (the same list that produced `shapes`) rather than from the
# iteration order of ds.coords, and `assign` is used instead of relying on
# the return value of the in-place `update`.
ds = ds.assign(kF=(['phis', *keys], np.moveaxis(kF, -1, 0)))

# +
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="polar")
# Polar plot of kF for the grid point nearest to the requested values.
kF_slice = ds.sel(
    experiment=0, misalignment=0, valley=1, alpha=1.2, method="nearest"
).kF
kF_slice.plot(ax=ax)
plt.show()
# -
# Write the full dataset, including kF, to disk.
ds.to_netcdf('./data/experimental_fermi.nc')

# !ssh hpc05 -C "killall dask-quantumtinkerer-server"
