# -*- coding: utf-8 -*-
# # Compute Fermi surface

# ## Imports
# ---

# +
# Import from modules
from modules import system

# Other imports
import kwant
from kwant.linalg import lll
import scipy
import numpy as np
import math
from tqdm import tqdm
import xarray as xr
import collections
import itertools
from copy import copy
import tinyarray as ta
# -

import matplotlib.pyplot as plt
from matplotlib import rc
# Global plot style: LaTeX-rendered Helvetica text at 12 pt, thicker lines.
rc("font", family="sans-serif", size=12, **{"sans-serif": ["Helvetica"]})
rc("text", usetex=True)
rc("lines", linewidth=2)
rc("legend", fontsize=12)

# ## Create system
# ---

# +
# Make infinite system
dimensions = dict()
device = system.Device(dimensions)
from modules import functions

# Bulk builder from the project's Device helper; the hopping callables come
# from modules.functions.
# NOTE(review): presumably a translationally invariant kwant.Builder -- it is
# passed to kwant.wraparound below; confirm against modules.system.
bulk = device.make_bulk(
    region="lead",
    hopping_t=functions.hopping_t,
    hopping_gamma3=functions.hopping_gamma3,
    hopping_gamma4=functions.hopping_gamma4,
)
# Finalized system used by `ham`, which passes the momenta as the parameters
# k_x and k_y.
wrapped = kwant.wraparound.wraparound(bulk, keep=None).finalized()
# -

# Compute m from displacement field
D = 2.578608538513666987e+08 # [V/m]
d = 3.48e-10 # [m] distance between layers
t = 2.8 # [eV]
from scipy.constants import e
# The elementary charge cancels: m = D * d / (2 * t), i.e. half of the
# potential drop D * d (in volts) divided by t (in eV), so m is dimensionless.
m = D * d * e / (t * e) / 2

# Set parameters without trigonal warping
params_0 = dict(
    B=0,
    mu=0.,
    dmu=0,
    m=m,
    alpha=0
)
# Set parameters with trigonal warping
params_1 = params_0.copy()
params_1['alpha'] = 1

# ## Bands

# ### Bulk bands

# +
# Generate Hamiltonian in k-space

lat_ndim = 2

# columns of B are lattice vectors
B = np.array(wrapped._wrapped_symmetry.periods).T
# columns of A are reciprocal lattice vectors
# (without the factor 2*pi, which is applied to `klat_points` below)
A = np.linalg.pinv(B).T

## calculate the bounding box for the 1st Brillouin zone

# Get lattice points that neighbor the origin, in basis of lattice vectors
reduced_vecs, transf = lll.lll(A.T)
neighbors = ta.dot(lll.voronoi(reduced_vecs), transf)
# Add the origin to these points.
klat_points = np.concatenate(([[0] * lat_ndim], neighbors))
# Transform to cartesian coordinates and rescale.
# Will be used in 'outside_bz' function, later on.
klat_points = 2 * np.pi * np.dot(klat_points, A.T)
# Calculate the Voronoi cell vertices
vor = scipy.spatial.Voronoi(klat_points)
# The origin was inserted first, so index 0 selects the Voronoi region around
# it; the vertices of that region are the corners of the first Brillouin zone.
around_origin = vor.point_region[0]
bz_vertices = vor.vertices[vor.regions[around_origin]]
# extract bounding box
# (largest absolute vertex coordinate along each axis; not used again in the
# visible part of this file)
k_max = np.max(np.abs(bz_vertices), axis=0)


def momentum_to_lattice(k):
    """Express the momentum `k` in the basis given by the columns of `A`.

    Solves ``A @ x = k`` in the least-squares sense with the module-level
    matrix `A` and returns the solution ``x``.

    Raises:
    -------
    RuntimeError
        If the solver reports a residual above 1e-7, i.e. `k` cannot be
        written as a combination of the columns of `A`.
    """
    solution = scipy.linalg.lstsq(A, k)
    k_lattice, residuals = solution[0], solution[1]
    if not np.any(abs(residuals) > 1e-7):
        return k_lattice
    raise RuntimeError(
        "Requested momentum doesn't correspond to any lattice momentum."
    )

def ham(k, params=None):
    """Return the Hamiltonian matrix of `wrapped` at the momentum `k`.

    Parameters:
    -----------
    k : sequence of two floats
        Momentum; it is converted with `momentum_to_lattice` and the two
        resulting components are passed to the system as `k_x` and `k_y`.
    params : dictionary, optional
        The remaining Hamiltonian parameters. Defaults to `params_0`
        (the parameter set without trigonal warping). Entries named `k_x`
        or `k_y` take precedence over the converted momentum.

    Returns:
    --------
    The result of `wrapped.hamiltonian_submatrix` for these parameters.
    """
    # The original body read a global called `params`, which this file never
    # defines; fall back to the default parameter set instead.
    if params is None:
        params = params_0
    k = momentum_to_lattice(k)
    return wrapped.hamiltonian_submatrix(params={**dict(k_x=k[0], k_y=k[1]), **params})


# -

def hamiltonian_array(syst, params=None, k_x=0, k_y=0, k_z=0, return_grid=False):
    """Evaluate the Hamiltonian of a system over a grid of parameters.
    Parameters:
    -----------
    syst : kwant.Builder object
        The un-finalized kwant system whose Hamiltonian is calculated.
    params : dictionary, optional
        A container of Hamiltonian parameters. The parameters that are
        sequences are used to loop over; strings and bytes count as single
        values. `None` is treated as an empty dictionary. The container
        passed in is never modified.
    k_x, k_y, k_z : floats or sequences of floats
        Momenta at which the Hamiltonian has to be evaluated.  If the system
        only has 1 translation symmetry, only `k_x` is used, and interpreted as
        lattice momentum. Otherwise the momenta are in reciprocal space.
    return_grid : bool
        Whether to also return the names of the variables used for expansion,
        and their values.
    Returns:
    --------
    hamiltonians : numpy.ndarray
        An array with the Hamiltonians. The first n-2 dimensions correspond to
        the expanded variables, ordered alphabetically by name. If nothing is
        expanded, a single Hamiltonian matrix is returned.
    parameters : list of tuples
        Names and ranges of values that were used in evaluation of the
        Hamiltonians. Only returned if `return_grid` is true; empty if
        nothing is expanded.
    Raises:
    -------
    ValueError
        If the system has a single translation symmetry but a nonzero `k_y`
        or `k_z` is requested.
    RuntimeError
        If one of the system parameters is named `k_x`, `k_y` or `k_z`, or if
        a requested momentum does not correspond to a lattice momentum.
    Examples:
    ---------
    >>> hamiltonian_array(syst, dict(t=1, mu=np.linspace(-2, 2)),
    ...                   k_x=np.linspace(-np.pi, np.pi))
    >>> hamiltonian_array(sys_2d, p, np.linspace(-np.pi, np.pi),
    ...                   np.linspace(-np.pi, np.pi))
    """
    # Prevent accidental mutation of input; `None` means "no parameters".
    params = {} if params is None else copy(params)

    def is_grid(value):
        # Strings are iterable, but a string parameter is a single value and
        # must not be expanded character by character.
        return isinstance(value, collections.abc.Iterable) and not isinstance(
            value, (str, bytes)
        )

    try:
        space_dimensionality = syst.symmetry.periods.shape[-1]
    except AttributeError:
        space_dimensionality = 0
    dimensionality = syst.symmetry.num_directions

    if dimensionality == 0:
        # No translation symmetry: the momenta are ignored altogether.
        syst = syst.finalized()

        def momentum_to_lattice(k):
            return {}

    else:
        if len(syst.symmetry.periods) == 1:

            def momentum_to_lattice(k):
                if any(k[dimensionality:]):
                    raise ValueError("Dispersion is 1D, but more momenta are provided.")
                return {"k_x": k[0]}

        else:
            # Columns of B are the lattice vectors; A = B (B^T B)^-1, so
            # solving A x = k yields the lattice momenta x.
            B = np.array(syst.symmetry.periods).T
            A = B @ np.linalg.inv(B.T @ B)

            def momentum_to_lattice(k):
                lstsq = np.linalg.lstsq(A, k[:space_dimensionality], rcond=-1)
                k, residuals = lstsq[:2]
                if np.any(abs(residuals) > 1e-7):
                    raise RuntimeError(
                        "Requested momentum doesn't correspond"
                        " to any lattice momentum."
                    )
                return dict(zip(["k_x", "k_y", "k_z"], list(k)))

        syst = kwant.wraparound.wraparound(syst).finalized()

    # Collect every variable that spans a grid axis.
    changing = {key: value for key, value in params.items() if is_grid(value)}

    for key, value in [("k_x", k_x), ("k_y", k_y), ("k_z", k_z)]:
        if key in changing:
            raise RuntimeError(
                "One of the system parameters is {}, "
                "which is reserved for momentum. "
                "Please rename it.".format(key)
            )
        if is_grid(value):
            changing[key] = value

    def hamiltonian(**values):
        # Momenta that are not part of the grid fall back to the scalars
        # given in the signature.
        k = [values.pop("k_x", k_x), values.pop("k_y", k_y), values.pop("k_z", k_z)]
        params.update(values)
        k = momentum_to_lattice(k)
        system_params = {**params, **k}
        return syst.hamiltonian_submatrix(params=system_params, sparse=False)

    if changing:
        names, values = zip(*sorted(changing.items()))
        hamiltonians = [
            hamiltonian(**dict(zip(names, point)))
            for point in itertools.product(*values)
        ]
    else:
        # Nothing to expand: unpacking zip() of an empty sequence would
        # raise, so use an empty grid and evaluate one Hamiltonian.
        names, values = (), ()
        hamiltonians = [hamiltonian(k_x=k_x, k_y=k_y, k_z=k_z)]
    size = list(hamiltonians[0].shape)

    hamiltonians = np.array(hamiltonians).reshape(
        [len(value) for value in values] + size
    )

    if return_grid:
        return hamiltonians, list(zip(names, values))
    else:
        return hamiltonians


# +
# Generate an array of Hamiltonians in the k-space grid

# Square momentum window, shifted by the coordinates of the first
# Brillouin-zone vertex.
ks = np.linspace(-0.15, 0.15, 201)
kxs, kys = (ks - bz_vertices[0][axis] for axis in range(2))

# One stack of Hamiltonians per parameter set (without / with trigonal warping)
hamiltonians_0, hamiltonians_1 = (
    hamiltonian_array(bulk, params=parameter_set, k_x=kxs, k_y=kys)
    for parameter_set in (params_0, params_1)
)
# -

# Diagonalize the Hamiltonians
energies_0, energies_1 = (
    np.linalg.eigvalsh(stack) for stack in (hamiltonians_0, hamiltonians_1)
)


def spectral_density(w, energies, eta=1e-5):
    """Return the spectral density ``-Im Tr G(w + i*eta) / pi``.

    Parameters:
    -----------
    w : float, complex or array of those
        Frequency (or frequencies). An array must broadcast against
        `energies`; it is not modified.
    energies : numpy.ndarray
        Eigenvalues, with the levels along the last axis.
    eta : float
        Imaginary shift added to `w`; it sets the width of the peaks.

    Returns:
    --------
    numpy.ndarray
        The density summed over the last axis of `energies`.
    """
    # Build a new shifted frequency instead of `w += ...`: the in-place form
    # silently modifies a complex array passed by the caller and fails with
    # a casting error for a real one.
    z = w + 1j * eta
    greens = 1 / (z - energies)
    # Trace over the levels; the last axis stays correct when `w` adds
    # leading broadcast dimensions.
    return -np.sum(greens, axis=-1).imag / np.pi


# Frequency grid; complex so that an imaginary broadening can be added to it
ws = np.linspace(0, 0.06, 500, dtype='complex')

# Compute k-resolved spectral density
# (each frequency yields one map per parameter set; transpose the pairs into
# two frequency-indexed stacks)
kdos0, kdos1 = map(
    np.array,
    zip(
        *[
            (spectral_density(w, energies_0), spectral_density(w, energies_1))
            for w in tqdm(ws)
        ]
    ),
)

# Stack and multiply by 4-fold degeneracy (2 valleys and 2 spins)
kdos = np.stack((kdos0, kdos1)) * 4

# Generate dataset
# NOTE(review): `a` is presumably the lattice constant in metres, used to
# rescale the momenta -- confirm.
a = 2.46e-10
ds = xr.Dataset(
    data_vars=dict(kdos=(['alpha', 'ws', 'kx', 'ky'], kdos)),
    coords=dict(kx=kxs / a, ky=kys / a, ws=ws.real, alpha=[0, 1]),
    attrs=dict(m=m),
)


def cumtrapz(A, dim):
    """Cumulative trapezoidal rule along the dimension `dim`.

    Parameters
    ----------
    A : xarray object with a coordinate named `dim`
        The integrand, sampled at the values of that coordinate.
    dim : str
        Name of the dimension to integrate along.

    Returns
    -------
    The running integral of `A` along `dim`. The first entry is zero.

    Notes
    -----
    The trapezoidal rule (not Simpson's rule) is given by
        int f (x) = sum (f_i + f_i+1) dx / 2
    """
    x = A[dim]
    step = {dim: 1}
    # Shifting by one leaves NaN in the first slot; filling with 0 gives the
    # first panel zero width.
    dx = (x - x.shift(**step)).fillna(0.0)
    panels = (A.shift(**step) + A) * dx / 2.0
    return panels.fillna(0.0).cumsum(dim)


# Calculate electron density
# |v0 x v1| is twice the area of the triangle (origin, v0, v1); six such
# triangles make up a hexagon, hence the factor 3. The 2-D cross product is
# written out because np.cross on 2-D vectors is deprecated since NumPy 2.0.
# NOTE(review): assumes bz_vertices[0] and bz_vertices[1] are adjacent corners
# of a hexagonal zone centred on the origin -- confirm the vertex ordering.
bz_area = 3 * np.abs(
    bz_vertices[0, 0] * bz_vertices[1, 1] - bz_vertices[0, 1] * bz_vertices[1, 0]
)
# Integrate over the momentum window, then accumulate over frequency
density = cumtrapz(ds.kdos.integrate(['kx', 'ky']) / bz_area, dim='ws')

# Compute Fermi energy for experimental electron density and match
n_exp = 2.850312109862672400e16  # [m^-2]
# For each alpha, pick the frequency whose cumulative density is closest to
# the target density
energy_0, energy_1 = (
    ws[
        np.argmin(
            np.abs((density.sel(alpha=alpha, method="nearest") - n_exp)).to_numpy()
        )
    ].real
    for alpha in (0, 1)
)

# Calculate spectral density at the Fermi energy
kdos0 = ds.kdos.sel(alpha=0, ws=energy_0, method='nearest').to_numpy()
kdos1 = ds.kdos.sel(alpha=1, ws=energy_1, method='nearest').to_numpy()
kdos_exp = np.stack([kdos0, kdos1])

# Generate dataset
ds_exp = xr.Dataset(
    data_vars={
        'kdos': (['alpha', 'kx', 'ky'], kdos_exp)
    },
    coords={
        'kx': kxs,
        'ky': kys,
        'alpha': [0, 1]
    },
    attrs={
        'm': m,
        'energy0': energy_0,
        'energy1': energy_1,
        'bz_point': bz_vertices[0]
    }
)

# Inspect Fermi surfaces
ds_exp.kdos.plot(col='alpha', vmax=10)

# Store data
# Create the output directory first so that a missing folder does not throw
# away the finished computation.
import os
os.makedirs('./data', exist_ok=True)
ds_exp.to_netcdf('./data/fermi_surf.nc')
