import numpy as np
import scipy.constants as c
from scipy import optimize, integrate

# Lattice constant [m]: sqrt(3) times 1.42e-10 m (presumably the C-C bond
# length).
a = 1.42e-10 * np.sqrt(3)
d = 3.35e-10  # [m] distance between layers
# Hopping constants (presumably in eV). NOTE(review): conventional bilayer
# labels -- gamma_0 intralayer, gamma_1 interlayer dimer, gamma_3 skew
# interlayer coupling (the source of trigonal warping); confirm.
gamma_0 = 3.16
gamma_1 = 0.381
gamma_3 = 0.386
# Convert hoppings to velocities. Since the hoppings are plain numbers and
# hbar is in J*s, v and v_3 are not in m/s on their own; in the code visible
# here they always appear multiplied by hbar (directly, or through
# p = hbar * k / a), so hbar cancels and v * p has the units of the hoppings.
v = np.sqrt(3) * a * gamma_0 / 2 / c.hbar
v_3 = np.sqrt(3) * a * gamma_3 / 2 / c.hbar
# Quantities used to compute the screened layer imbalance.
# Characteristic density gamma_1**2 / (pi * hbar**2 * v**2); hbar cancels,
# leaving 4 * gamma_1**2 / (3 * pi * a**2 * gamma_0**2), i.e. 1/m**2.
n_perp = gamma_1**2 / (np.pi * c.hbar**2 * v**2)
# Relative permittivity.
epsilon_r = 2
# Screening parameter.
# NOTE(review): d * e**2 * n_perp / (gamma_1 * epsilon_0) is dimensionless
# only if gamma_1 is an energy in joules. With gamma_1 = 0.381 (eV) this
# leaves a stray factor of e, so Lambda evaluates to ~1e-19 and the screening
# correction is effectively switched off. The intended form is probably
# d * c.e * n_perp / (2 * gamma_1 * c.epsilon_0 * epsilon_r); confirm before
# changing, since the log term that multiplies Lambda must also be correct.
Lambda = d * c.e**2 * n_perp / (2 * gamma_1 * c.epsilon_0)
# Reciprocal lattice vectors, in units of 1/a.
b1 = 2 * np.pi * np.array([1, 1 / np.sqrt(3)])
b2 = 2 * np.pi * np.array([1, -1 / np.sqrt(3)])
# Area of the Brillouin zone in units of 1/a**2: |det[b1, b2]|.
# Written out explicitly because np.cross on 2-D vectors is deprecated
# since NumPy 2.0.
bz_area = np.abs(b1[0] * b2[1] - b1[1] * b2[0])


def Gamma(p, phi, phic, U, valley, alpha):
    """
    Evaluate the term under the inner square root of the dispersion relation.

    The result is the sum of three pieces:
    ``(gamma_1**2 - v_3**2 p**2)**2 / 4``,
    ``v**2 p**2 (gamma_1**2 + U**2 + v_3**2 p**2)`` and the angle-dependent
    ``2 alpha valley gamma_1 v_3 v**2 p**3 cos(3 phi + phic)``.
    All operations are NumPy element-wise, so array arguments broadcast.

    Parameters:
    -----------

    p: float or array
        Momentum variable (the caller ``epsilon`` passes ``hbar * k / a``).
    phi: float or array
        Polar angle.
    phic: float
        Crystal orientation angle (phase offset of the angular term).
    U: float
        Interlayer imbalance.
    valley: int (-1, 1)
        Valley quantum number; flips the sign of the angular term.
    alpha: float
        Degree of trigonal warping; scales the angular term.

    Returns
    -------
    float or array
        Gamma term in the dispersion relation.
    """
    p_sq = p**2
    quartic = (gamma_1**2 - v_3**2 * p_sq) ** 2 / 4
    isotropic = v**2 * p_sq * (gamma_1**2 + U**2 + v_3**2 * p_sq)
    # Angular (trigonal-warping) contribution; factor order kept as-is.
    warping = (
        alpha * valley * 2 * gamma_1 * v_3 * v**2 * p**3 * np.cos(3 * phi + phic)
    )
    partial_sum = quartic + isotropic
    return partial_sum + warping


def epsilon(k, phi, phic, U, valley, mu, alpha=1):
    """
    Return the band energy at (k, phi) measured relative to ``mu``.

    With ``p = hbar * k / a`` the band energy is
    ``sqrt(gamma_1**2 / 2 + U**2 / 4 + (v**2 + v_3**2 / 2) p**2 - sqrt(Gamma))``,
    i.e. the branch with a minus sign in front of ``sqrt(Gamma)``. The function
    returns that energy minus ``mu``. Its zeros in ``k`` therefore trace the
    constant-energy contour at ``mu``, and with ``mu = 0`` it returns the band
    energy itself.

    Parameters:
    -----------

    k: float or array
        Dimensionless momentum; the physical momentum used is ``hbar * k / a``.
    phi: float or array
        Polar angle.
    phic: float
        Crystal orientation angle.
    U: float
        Interlayer imbalance.
    valley: int (-1, 1)
        Valley quantum number.
    mu: float
        Chemical potential (subtracted from the band energy).
    alpha: float, optional
        Degree of trigonal warping (default 1).

    Returns
    -------
    float or array
        Band energy minus ``mu``.
    """
    momentum = c.hbar * k / a
    gap_part = gamma_1**2 / 2 + U**2 / 4
    kinetic_part = (v**2 + v_3**2 / 2) * momentum**2
    inner_root = np.sqrt(Gamma(momentum, phi, phic, U, valley, alpha))
    band_energy = np.sqrt(gap_part + kinetic_part - inner_root)
    return band_energy - mu


def compute_mu_and_U(n_exp, D, *, nk=2000, k_max=1.2e-1, e_max=0.1):
    """
    Computes chemical potential, interlayer imbalance, and bottom of the conduction band.

    The screened layer imbalance ``U`` is obtained from the external one
    (``D * d / epsilon_r``) through a logarithmic screening correction. The
    conduction-band dispersion (valley 1, ``phic = 0``) is then sampled on a
    polar (k, phi) grid and turned into a density of states with a weighted
    histogram. That density is integrated to find the energy at which the
    carrier density matches ``|n_exp|``.

    Parameters:
    -----------

    n_exp: float
        Electron density; only its magnitude is used. It is compared with a
        density built from k-space weights divided by ``a**2`` and with
        ``n_perp``, so presumably it is in 1/m^2 -- confirm with callers.
    D: float
        Displacement field. ``D * d / epsilon_r`` is used as the external
        layer imbalance, so presumably D is in V/m -- confirm.
    nk: int, optional
        Number of grid points in k and in phi, and number of histogram bins.
    k_max: float, optional
        Largest dimensionless momentum sampled (physical momentum is k / a).
    e_max: float, optional
        Energy cutoff; only grid points with energy below it enter the DOS.

    Returns
    -------
    mu: float
        Chemical potential: the left edge of the energy bin at which the
        integrated density is closest to ``|n_exp|``. If ``|n_exp|`` exceeds
        the density reachable below ``e_max``, this saturates at the top of
        the sampled window.
    U: float
        Layer imbalance.
    Ec: float
        Bottom of conduction band (minimum of the sampled dispersion).

    Raises
    ------
    ValueError
        If no sampled energy lies below ``e_max``.
    """
    n_abs = np.abs(n_exp)
    # Compute U from the displacement field. Both density terms inside the
    # log are dimensionless ratios to n_perp.
    Uext = D * d / epsilon_r
    U = Uext / (
        1
        - Lambda
        / 2
        * np.log(
            n_abs / (2 * n_perp)
            + 1 / 2 * np.sqrt((n_abs / n_perp) ** 2 + (Uext / 2 / gamma_1) ** 2)
        )
    )
    # Compute dispersion on a polar grid (mu = 0, so es is the band energy).
    ks = np.linspace(0, 1, nk) * k_max
    phis = np.linspace(0, 2 * np.pi, nk)
    k, phi = np.meshgrid(ks, phis)
    es = epsilon(k, phi, 0, U, 1, 0)
    below = es < e_max
    if not below.any():
        raise ValueError("no dispersion points below e_max; increase e_max or k_max")
    # Physical k-space area of each grid point: k dk dphi / a**2.
    cell_area = k * np.diff(ks)[0] * np.diff(phis)[0] / a**2
    # Compute density of states: histogram of energies weighted by area.
    y, x = np.histogram(es[below], bins=nk, weights=cell_area[below])
    # Only valley 1 is sampled; the other valley is assumed to contribute
    # the same DOS, hence gv = 2.
    gs, gv = 2, 2
    # NOTE(review): standard 2-D state counting divides the k-space area by
    # (2 * pi)**2, whereas bz_area / 2 equals (2 * pi)**2 / sqrt(3) -- confirm
    # the intended normalization.
    dos = gs * gv * y / np.diff(x)[0] / (bz_area / 2)
    # Compute electron density. initial=0 makes n[i] the density integrated
    # up to the left bin edge x[i], so n and x[:-1] share indices.
    n = integrate.cumulative_trapezoid(y=dos, x=x[:-1], initial=0)
    # Compute chemical potential.
    mu = x[np.argmin(np.abs(n - n_abs))]

    return mu, U, np.min(es)
