from functools import partial
import numpy as np
import xarray as xr
from scipy import optimize, interpolate, signal, integrate, special
from scipy.sparse.linalg import eigsh
import scipy.constants as c
from . import dispersion
import matplotlib.pyplot as plt

# Lattice constant, in metres (it is combined with SI values from
# scipy.constants in k_to_r). Presumably the carbon-carbon distance of
# 1.42 Angstrom times sqrt(3) -- confirm.
a = 1.42e-10 * np.sqrt(3)


def k_F(phi, phic, valley, mu, U, alpha):
    """
    Solve for the Fermi momentum at every polar angle.

    For each angle in ``phi`` a root of ``dispersion.epsilon`` in its first
    argument is found with ``scipy.optimize.fsolve`` (starting guess 0.1),
    and the first component of that root is kept.

    Parameters:
    -----------

    phi: iterable of float
        Polar angles at which the Fermi momentum is computed.
    phic: float
        Crystal orientation angle.
    valley: int (-1, 1)
        Valley quantum number.
    mu: float
        Chemical potential.
    U: float
        Interlayer imbalance.
    alpha: float
        Degree of trigonal warping.

    Returns
    -------
    array
        Fermi momentum for each entry of ``phi`` (the Fermi surface in polar
        coordinates).
    """
    momenta = []
    for angle in phi:
        # fsolve calls epsilon(k, angle, phic, U, valley, mu, alpha).
        root = optimize.fsolve(
            func=dispersion.epsilon,
            x0=0.1,
            args=(angle, phic, U, valley, mu, alpha),
        )
        momenta.append(root[0])
    return np.array(momenta)


def kx_ky(k_F, phi, phic):
    """
    Convert a Fermi surface from polar to cartesian coordinates.

    Parameters:
    -----------

    k_F: float or array
        Fermi momentum for each polar angle.
    phi: float or array
        Polar angle.
    phic: float
        Crystal orientation angle; it is subtracted from ``phi`` before
        projecting onto the x and y axes.

    Returns
    -------
    array
        ``[kx, ky]`` stacked along the first axis.
    """
    angle = phi - phic
    kx = k_F * np.cos(angle)
    ky = k_F * np.sin(angle)
    return np.array([kx, ky])


def k_to_r(k_F, phi, B, phic):
    """
    Converts the Fermi surface to a real-space trajectory according to B.

    The real-space point for each state has magnitude
    ``hbar * k_F / (a * e * B)`` and points along the momentum direction
    ``phi - phic`` rotated by -90 degrees.

    Parameters:
    -----------

    k_F: float or array
        Fermi momentum for corresponding polar angle (assumed to be expressed
        in units of 1/a, since it is divided by the lattice constant ``a`` --
        TODO confirm).
    phi: float or array
        Polar angle. Not modified.
    B: float
        Magnetic field in tesla.
    phic: float
        Crystal orientation angle.

    Returns
    -------
    r: array
        Cyclotron orbit in cartesian coordinates, ``r[0]`` = x and ``r[1]`` = y.
    dr: array
        Real-space displacement between neighboring polar angles
        (``np.diff`` along the sample axis, one entry fewer than ``r``).
    """
    # Compute a new array: the former in-place ``phi -= phic`` mutated the
    # caller's array (e.g. a view of dataset data) and fails for read-only
    # or integer inputs.
    rotated = np.asarray(phi) - phic
    field_scaled = c.hbar / a / c.e / B
    r_polar = field_scaled * k_F
    r = np.array(
        [
            r_polar * np.cos(rotated - np.pi / 2),
            r_polar * np.sin(rotated - np.pi / 2),
        ]
    )

    dr = np.diff(r, axis=1)
    return r, dr


def calculate_injection_prob(dr, cos=True):
    """
    Computes the injection probabilities for a given state for
    an edge according to its normal angle.

    Each state is weighted by the length of its real-space displacement;
    with ``cos=True`` the weight is additionally multiplied by
    ``cos(theta - pi/2)``. Negative weights (displacements pointing into
    y < 0) are clipped to zero and the result is normalised to sum to one.

    Parameters:
    -----------

    dr: nd-array
        Real space displacement at injection; ``dr[0]`` = dx and ``dr[1]`` = dy.
    cos: boolean
        If True, includes cosine law.

    Returns
    -------
    prob: array
        Normalised injection probability per state. If every weight is zero
        (e.g. all displacements point downward), an all-zero array is
        returned instead of NaNs.
    theta: array
        Direction of each displacement, ``arctan2(dy, dx)``.
    """
    dr = np.asarray(dr, dtype=float)
    step = np.linalg.norm(dr, axis=0)
    theta = np.arctan2(dr[1], dr[0])
    # The former division by max(step) cancels in the final normalisation
    # and divided by zero for all-zero displacements, so it is omitted.
    weight = step * np.cos(theta - np.pi / 2) if cos else step.copy()
    weight[weight < 0] = 0
    total = np.sum(weight)
    if total > 0:
        weight = weight / total
    return weight, theta


def dispersion_cut(ky, kx, U, mu, valley, phic, alpha):
    """
    Evaluate the dispersion at cartesian momenta.

    ``kx`` and ``ky`` are combined elementwise into a magnitude
    ``sqrt(kx**2 + ky**2)`` and an angle ``arctan2(ky, kx)``, which are
    forwarded to ``dispersion.epsilon``. Scalars, 1-D cuts or full meshgrids
    can therefore be passed, as long as ``dispersion.epsilon`` accepts them.

    Parameters:
    -----------

    ky: float or nd array
        y-coordinate of momentum.
    kx: float or nd array
        x-coordinate of momentum.
    U: float
        Layer imbalance.
    mu: float
        Chemical potential.
    valley: int (-1 or 1)
        Valley quantum number.
    phic: float
        Crystal orientation angle.
    alpha: float
        Degree of trigonal warping.

    Returns
    -------
    array
        ``dispersion.epsilon`` evaluated at every (kx, ky) pair.
    """
    return dispersion.epsilon(
        np.sqrt(kx**2 + ky**2),
        np.arctan2(ky, kx),
        phic,
        U,
        valley,
        mu,
        alpha,
    )


def compute_wf(kx, es, pot_conc):
    """
    Computes wavefunction in momentum space for a fixed ky.

    Builds the sparse Hamiltonian ``H = pot_conc * D + diag(es)``, where
    ``D`` is the discrete second-difference operator
    ``(2 psi_i - psi_{i-1} - psi_{i+1}) / dk**2``. It then takes the
    eigenvector whose eigenvalue is closest to zero (shift-invert ``eigsh``
    with ``sigma=0``) and normalises it so that ``integral |psi|^2 dkx = 1``.

    Parameters:
    -----------

    kx: nd array
        x-coordinate of momentum. Assumed uniformly spaced: only the first
        spacing is used as ``dk``.
    es: nd array
        Dispersion for the corresponding kx and fixed ky.
    pot_conc: float
        Concavity of the electrostatic potential.

    Returns
    -------
    callable
        Linear interpolant ``rho(kx)`` of the normalised probability density;
        it returns 0 outside the sampled kx range.
    """
    from scipy import sparse

    N = len(kx)
    dk = np.diff(kx)[0]

    off_diagonals = np.array([np.ones(N - 1), np.ones(N - 1)])
    Vkx = 2 * sparse.eye(N) - sparse.diags(off_diagonals, offsets=(-1, 1))
    Vkx /= dk**2
    Vkx *= pot_conc
    Ekx = sparse.diags(es)
    H = Vkx + Ekx

    vals, vecs = eigsh(H, sigma=0, k=1)

    psi = vecs[:, 0]
    # ``integrate.trapz`` was removed in SciPy 1.14; ``trapezoid`` is the
    # supported name (available since SciPy 1.6).
    psi /= np.sqrt(integrate.trapezoid(np.abs(psi) ** 2, kx))
    rho = np.abs(psi) ** 2

    return interpolate.interp1d(kx, rho, fill_value=0, bounds_error=False)


def quantization(k, phic, U, valley, mu, W, alpha, Ec, Ecmax):
    """
    Computes Fermi surface occupation due to size quantization in the injector.

    The injector is modelled by a parabolic potential of curvature
    ``pot_conc``. For every ky on a grid, a count ``dkmax`` is obtained by
    integrating ``sqrt(max(-es, 0) / pot_conc)`` over kx and dividing by pi
    (presumably a Bohr-Sommerfeld-like subband count -- confirm). For every
    subband index ``m = 1..n``, the ky row whose count ``dkmax + 1/2`` is
    closest to ``m`` is passed to ``compute_wf``. The resulting probability
    density, evaluated at ``k[0]``, is added to the occupation.

    Parameters:
    -----------

    k: nd array
        Momentum in cartesian coordinates; ``k[0]`` holds kx and ``k[1]`` ky.
    phic: float
        Crystal orientation angle.
    U: float
        Layer imbalance.
    valley: int (-1 or 1)
        Valley quantum number.
    mu: float
        Chemical potential.
    W: float
        Width of the injector.
    alpha: float
        Degree of trigonal warping.
    Ec: float
        Minimum of the valence band in the QPC.
    Ecmax: float
        Minimum of the valence band in the depleted region.

    Returns
    -------
    occupation: array
        Occupation of the Fermi surface at the kx values ``k[0]``; all zeros
        if no subband is occupied.
    n: float
        Number of occupied subbands, ``max(floor(dkmax + 1/2))``.
    """
    # Curvature chosen so that pot_conc * (W / (2 a))**2 == Ecmax + mu - Ec.
    pot_conc = 4 * (Ecmax + mu - Ec) / (W / a) ** 2
    nu = 1 / 2  # offset added to the subband count before flooring

    kx_lim = [np.min(k[0]), np.max(k[0])]

    N = 1001
    # kx grid spans three times the Fermi-surface kx range, presumably so the
    # wavefunctions decay inside the window.
    _kx = np.sort(3 * np.linspace(*kx_lim, N))
    _ky = np.linspace(0, np.max(k[1]), N)
    _kxx, _kyy = np.meshgrid(_kx, _ky)
    es = dispersion_cut(
        kx=_kxx, ky=_kyy, phic=phic, U=U, mu=mu, valley=valley, alpha=alpha
    )

    # Drop kx columns where the dispersion is NaN for any ky.
    valid = ~np.isnan(es).any(axis=0)
    _kx = _kx[valid]
    es = es[:, valid]

    # ``integrate.trapz`` was removed in SciPy 1.14; use ``trapezoid``.
    dkmax = (
        integrate.trapezoid(
            np.sqrt(-es * np.heaviside(-es, 0) / pot_conc), _kx, axis=1
        )
        / np.pi
    )

    n = np.max(np.floor(dkmax + nu))

    # Accumulate explicitly instead of np.sum(generator), which is a
    # deprecated NumPy code path; n == 0 leaves the occupation at zero.
    occupation = np.zeros(k.shape[-1])
    for m in 1 + np.arange(n):
        row = np.argmin(np.abs(dkmax + nu - m))
        rho = compute_wf(kx=_kx, es=es[row, :], pot_conc=pot_conc)
        occupation = occupation + rho(k[0])
    return occupation, n


def x_position(r, initial_index, valley, L, x0=0, color=None, plot=False):
    """
    Compute position where cyclotron orbits end.

    The orbit is rolled so that it starts at ``initial_index`` and shifted
    so that its first point is the origin. It is then truncated just before
    its highest point, and the crossing of the line ``y = L`` is located
    between the last sample with ``0 <= y <= L`` and the following sample.

    Parameters:
    -----------

    r: nd array
        Coordinates of cyclotron orbit; ``r[0]`` = x and ``r[1]`` = y, with
        samples along axis 1. Not modified.
    initial_index: int
        Index for initial point of the trajectory.
    valley: int (-1, 1)
        Valley quantum number (selects the plot colour: 1 -> red, -1 -> blue).
    L: float
        Distance between injector and collector
    x0: float
        Initial x-coordinate. NOTE(review): currently unused -- the returned x
        is measured relative to the injection point. Confirm intent.
    color: float
        Transparency of points in scatter plot. Only used if plot=True;
        defaults to 1.0 when None.
    plot: boolean
        If True, plot trajectory (only for every 5th initial_index).

    Returns
    -------
    array
        ``[x, angle]``: the x-coordinate where the trajectory crosses
        ``y = L`` and the direction ``arctan2(dy, dx)`` of the crossing
        segment, or ``[inf, 0]`` if no valid crossing is found.
    """
    orbit = np.roll(r, -initial_index, axis=1)
    orbit[0, :] -= orbit[0, 0]
    orbit[1, :] -= orbit[1, 0]

    # Keep the part of the trajectory before its maximum y.
    orbit = orbit[:, : np.argmax(orbit[1])]

    # Samples inside the strip between injector (y = 0) and collector (y = L).
    in_strip = (orbit[1] >= 0) & (orbit[1] <= L)
    strip_idx = np.flatnonzero(in_strip)
    if strip_idx.size == 0:
        return np.array([np.inf, 0])
    last = strip_idx[-1]
    if last + 1 >= orbit.shape[1]:
        # No sample after the strip, so no crossing segment exists.
        return np.array([np.inf, 0])
    r1 = orbit[:, last]
    r2 = orbit[:, last + 1]
    dr = r2 - r1
    rmed = (r1 + r2) / 2
    collection_angle = np.arctan2(dr[1], dr[0])
    rf = orbit[:, in_strip]

    if plot and rf.any() and not (initial_index % 5):
        x = rf[0]
        y = rf[1]
        # Indexed by valley; a local name avoids shadowing scipy.constants ``c``.
        valley_colors = [None, "r", "b"]
        alphas = (1.0 if color is None else color) * np.ones(len(x))
        plt.scatter(
            x * 1e6,
            y * 1e6,
            c=valley_colors[valley],
            lw=3,
            ls="--",
            alpha=np.abs(alphas),
            zorder=np.max(alphas),
        )

    # Accept only if the crossing segment actually straddles y = L.
    if rf.any() and (np.abs(rmed[1] - L) < np.abs(dr[1])):
        return np.array([rmed[0], collection_angle])
    return np.array([np.inf, 0])


def overlap_gaussian(x0, x1, x2, w):
    """
    Creates a soft-wall collector with finite width.

    Returns the mass of a Gaussian of standard deviation ``w`` centred at
    ``x0`` that falls between ``x1`` and ``x2``, i.e.
    ``|erf((x1 - x0) / (w sqrt 2)) - erf((x2 - x0) / (w sqrt 2))| / 2``.

    Parameters
    ----------
    x0 : float or array
        Final position of the trajectory.
    x1 : float
        Initial x-coordinate of the collector.
    x2 : float
        Final x-coordinate of the collector.
    w : float
        Broadening of the collector potential.

    Returns
    -------
    float or array
        Probability of absorption by the collector.
    """
    scale = w * np.sqrt(2)
    lower = special.erf((x1 - x0) / scale)
    upper = special.erf((x2 - x0) / scale)
    return np.abs(lower - upper) / 2


def collimation_calc(
    dataset,
    B,
    offset=0,
    size_quantization=True
):
    """
    Compute collimation "spectra".

    Trajectories start from every Fermi-surface state whose initial
    real-space displacement points into the upper half-plane
    (0 <= angle <= pi). ``x_position`` propagates each one to the collector
    line ``y = L``. The absorbed current sums, over these states, the
    injection cosine, the occupation, the smoothed landing-position
    distribution times a soft-wall collector window, and an angular
    collection efficiency.

    Parameters:
    -----------
    dataset: xarray.Dataset
        Dataset with experimental parameters and Fermi surfaces. Reads
        ``kF``, ``phis``, ``U``, ``valley``, ``mu``, ``Ec``, ``Ecmax``, ``W``,
        ``alpha`` and ``attrs["L"]``.
    B: float
        Magnetic field in tesla. Must be non-zero; its sign sets the order in
        which the Fermi surface is traversed.
    offset: float
        x-coordinate distance between the center of injector and collector.
    size_quantization: bool
        If True, weight states by the occupation from ``quantization``;
        otherwise every state has occupation 1.

    Returns
    -------
    float
        Current absorbed by the collector (0 if no state is injected or no
        trajectory lands).

    Raises
    ------
    ValueError
        If ``B`` is zero.
    """
    from scipy import ndimage

    if B == 0:
        raise ValueError("B must be non-zero")

    # Negative fields traverse the Fermi surface in reverse order.
    step = int(np.sign(B))
    kF = dataset.kF.data[::step]
    polar_angles = dataset.phis.data[::step]
    phic = 0
    k = kx_ky(k_F=kF, phi=polar_angles, phic=phic)
    r, dr = k_to_r(
        k_F=kF,
        phi=polar_angles,
        phic=phic,
        B=B,
    )
    dIdphi, injection_angle = calculate_injection_prob(dr, cos=False)
    filter_injection = (injection_angle >= 0) & (injection_angle <= np.pi)
    if not filter_injection.any():
        # np.vectorize with a new output core dimension cannot be called on a
        # size-0 input, and the angular histogram would be empty.
        return 0

    if size_quantization:
        occupation, _ = quantization(
            k=k,
            phic=phic,
            U=dataset.U.data,
            valley=dataset.valley.data,
            mu=dataset.mu.data,
            Ec=dataset.Ec.data,
            Ecmax=dataset.Ecmax.data,
            W=dataset.W.data,
            alpha=dataset.alpha.data,
        )
        # dr (and hence the filter) has one entry fewer than k.
        occupation = occupation[:-1][filter_injection]
    else:
        occupation = 1
    injection_angle = injection_angle[filter_injection]
    dIdphi = dIdphi[filter_injection]

    # Angular collection efficiency: smoothed, peak-normalised histogram of
    # the injection angles weighted by dIdphi.
    nbins = 300
    angle_hist, angle_edges = np.histogram(
        injection_angle, bins=nbins, weights=dIdphi
    )
    # sigma of nbins * 5 / 180 bins, i.e. 5 degrees if the bins span 180.
    angle_hist = ndimage.gaussian_filter(angle_hist, sigma=nbins * 5 / 180)
    smearing = interpolate.interp1d(
        angle_edges[1:],
        angle_hist / np.max(angle_hist),
        fill_value=0,
        bounds_error=False,
    )
    # NOTE(review): the previous version re-weighted dIdphi by
    # cos(injection_angle - pi/2) * occupation here but never used the result.
    # Ic below does not include dIdphi; confirm that this is intended.

    # Compute focusing positions
    x_final = partial(
        x_position,
        r=r,
        L=dataset.attrs["L"],
        valley=dataset.valley.data,
    )
    vec_x_focus = np.vectorize(x_final, signature="()->(n)")
    collector_data = vec_x_focus(initial_index=np.where(filter_injection)[0])
    xc, collection_angle = collector_data[:, 0], collector_data[:, 1]

    W = dataset.W.data
    # Distribution of landing positions within +-10 W, smoothed by about W/4.
    landed = xc[(xc > -10 * W) & (xc < 10 * W)]
    xc_counts, xc_edges = np.histogram(landed, bins=nbins)
    # gaussian_filter keeps the input dtype; cast the integer counts to float
    # so the smoothed distribution is not truncated to integers.
    xc_counts = ndimage.gaussian_filter(
        xc_counts.astype(float),
        sigma=nbins * W / 4 / np.abs(np.max(xc_edges) - np.min(xc_edges)),
    )

    smearing2 = interpolate.interp1d(
        xc_edges[:-1],
        xc_counts
        * overlap_gaussian(
            xc_edges[:-1],
            -W / 2 + offset,
            W / 2 + offset,
            W / 4,
        ),
        fill_value=0,
        bounds_error=False,
    )

    # Missed trajectories (xc = inf) fall outside the interpolation range -> 0.
    xc_dist = smearing2(xc)

    collector_efficiency = smearing(collection_angle) * np.cos(
        collection_angle - np.pi / 2
    )

    if not xc.any():
        return 0
    # Weighted sum over all injected states.
    return np.sum(
        np.cos(injection_angle - np.pi / 2)
        * occupation
        * xc_dist
        * collector_efficiency
    )
