import numpy as np
import kwant
from . import const, functions, shapes


class Device:
    """Builder for a multi-terminal kwant system on a two-layer lattice.

    The lattice has four single-orbital sublattices, labelled A and B on
    layer 1 (``a1``, ``b1``) and on layer 2 (``a2``, ``b2``).  Calling
    :meth:`make_device` fills the scattering region, attaches two
    "armchair" leads (translation along ``(-2, 1)`` in lattice
    coordinates, i.e. along the x axis) chosen by ``experiment``, attaches
    two "zigzag" leads (translation along ``(0, +/-1)``, i.e. along the y
    axis), and finalizes the system.

    Parameters
    ----------
    dimensions : mapping
        Geometry parameters.  This class itself only reads
        ``dimensions["L_device"]``; the whole ``Device`` is also handed to
        the ``shapes`` helpers, which may read further keys.
    experiment : str, optional
        ``"longitudinal"`` or ``"transverse"``; selects which pair of
        armchair leads is attached.  Any other value attaches no armchair
        leads at all.
    prefactor : number, optional
        Forwarded to ``shapes.lead_transverse`` for the lower lead of the
        transverse setup.
    """

    def __init__(self, dimensions, experiment="longitudinal", prefactor=1):
        self.dimensions = dimensions
        self.experiment = experiment
        self.prefactor = prefactor

    def make_device(self):
        """Assemble and finalize the full system.

        Sets ``self.bulk`` and ``self.bulk_qpcs`` (bulk templates for the
        scattering region and for the armchair leads), ``self.syst`` (the
        unfinalized builder) and ``self.fsyst`` (the finalized system).

        Leads are attached in this order: the two armchair leads of the
        selected experiment (if any), then the "upper" zigzag lead, then the
        "lower" zigzag lead.
        """
        hopping_functions = {
            "hopping_t": functions.hopping_t,
            "hopping_gamma3": functions.hopping_gamma3,
            "hopping_gamma4": functions.hopping_gamma4,
        }
        self.bulk = self.make_bulk(region="scattering", **hopping_functions)
        self.bulk_qpcs = self.make_bulk(region="lead", **hopping_functions)

        # Fill the scattering region from the bulk template.
        self.syst = kwant.Builder()
        self.syst.fill(
            self.bulk,
            lambda site: shapes.shape(site, self),
            (0, 0),
            max_sites=float("inf"),
        )

        # Armchair leads: -2 * v1 + 1 * v2 = (-sqrt(3) * a, 0), so these
        # leads are translation invariant along the x axis.
        armchair_symmetry = self.bulk_sym.subgroup((-2, 1))

        if self.experiment == "longitudinal":
            self._attach_armchair_lead(
                armchair_symmetry,
                lambda site: shapes.leads(
                    site=site, device=self, orientation="armchair", prefactor=1
                ),
                (0, 0),
            )
            # NOTE(review): the prefactor of this second lead is hard-coded
            # to 10; the original source had ``self.prefactor`` commented out
            # next to it.  Kept as-is to preserve behavior -- confirm whether
            # ``self.prefactor`` is what is intended here.
            self._attach_armchair_lead(
                armchair_symmetry,
                lambda site: shapes.leads(
                    site=site, device=self, orientation="armchair", prefactor=10
                ),
                (0, 0),
                reverse=True,
            )
        elif self.experiment == "transverse":
            # Start sites of the two leads sit symmetrically above and below
            # the origin.
            y_start = self.dimensions["L_device"] * 1.5 * 1.25 / 2
            self._attach_armchair_lead(
                armchair_symmetry,
                lambda site: shapes.lead_transverse(site, self, "upper", 1),
                (0, y_start),
            )
            self._attach_armchair_lead(
                armchair_symmetry,
                lambda site: shapes.lead_transverse(
                    site, self, "lower", self.prefactor
                ),
                (0, -y_start),
            )
        # NOTE(review): any other ``experiment`` value silently attaches no
        # armchair leads; only the zigzag leads below are added.

        self.make_leads("upper")
        self.make_leads("lower")

        self.fsyst = self.syst.finalized()

    def _attach_armchair_lead(self, symmetry, shape, start, reverse=False):
        """Fill one armchair lead from ``self.bulk_qpcs`` and attach it.

        Parameters
        ----------
        symmetry : kwant.TranslationalSymmetry
            Symmetry of the lead builder.
        shape : callable
            ``shape(site) -> bool`` selecting the sites of the lead.
        start : tuple
            Real-space point from which filling starts.
        reverse : bool, optional
            If true, attach ``lead.reversed()`` (the lead with the opposite
            translation direction) instead of the lead itself.
        """
        lead = kwant.Builder(symmetry)
        lead.fill(self.bulk_qpcs, shape, start, max_sites=float("inf"))
        self.syst.attach_lead(lead.reversed() if reverse else lead, add_cells=5)

    def make_leads(self, which="upper"):
        """Attach one zigzag lead to ``self.syst``.

        Parameters
        ----------
        which : {"upper", "lower"}
            Selects the translation direction ``(0, +1)`` or ``(0, -1)`` in
            lattice coordinates, the matching ``functions.hopping_<which>_*``
            hopping functions, and the sign of the start point
            ``(0, +/- L_device / 2)``.

        Raises
        ------
        ValueError
            If ``which`` is neither ``"upper"`` nor ``"lower"``.

        Notes
        -----
        Calls :meth:`make_bulk`, which overwrites ``self.lat``,
        ``self.a1``/``b1``/``a2``/``b2`` and ``self.bulk_sym``.
        ``self.syst`` must already exist.
        """
        if which == "upper":
            sign = 1
            hopping_t = functions.hopping_upper_t
            hopping_gamma3 = functions.hopping_upper_gamma3
            hopping_gamma4 = functions.hopping_upper_gamma4
        elif which == "lower":
            sign = -1
            hopping_t = functions.hopping_lower_t
            hopping_gamma3 = functions.hopping_lower_gamma3
            hopping_gamma4 = functions.hopping_lower_gamma4
        else:
            raise ValueError(
                f"which must be 'upper' or 'lower', got {which!r}"
            )

        # Uses the default region ('lead'), i.e. the lead onsite functions.
        bulk_leads = self.make_bulk(
            hopping_t=hopping_t,
            hopping_gamma3=hopping_gamma3,
            hopping_gamma4=hopping_gamma4,
        )

        # Zigzag leads: translation along +/- v2 = (0, +/- a).  Each
        # sublattice is registered with (-2, 1) as the complementary lattice
        # vector of the lead's unit cell.
        zigzag_symmetry = self.bulk_sym.subgroup((0, sign))
        for family in (self.a1, self.a2, self.b1, self.b2):
            zigzag_symmetry.add_site_family(family, other_vectors=[(-2, 1)])

        lead_zigzag = kwant.Builder(zigzag_symmetry)
        lead_zigzag.fill(
            bulk_leads,
            lambda site: shapes.leads(
                site=site, device=self, orientation="zigzag", prefactor=1
            ),
            (0, sign * self.dimensions["L_device"] / 2),
            max_sites=float("inf"),
        )

        # Unlike the armchair leads, no extra lead cells are added here.
        self.syst.attach_lead(lead_zigzag, add_cells=0)

    def make_bulk(self, hopping_t, hopping_gamma3, hopping_gamma4, region='lead'):
        """Build the translationally invariant bulk template.

        Parameters
        ----------
        hopping_t, hopping_gamma3, hopping_gamma4
            Hopping values (or value functions) assigned to the in-layer
            A-B bonds, the a2-b1 bonds, and the b1-b2 / a1-a2 bonds
            respectively.
        region : {"lead", "scattering"}
            Selects the onsite functions: ``functions.onsite_1/2`` for
            ``"scattering"``, ``functions.onsite_leads_1/2`` for ``"lead"``.

        Returns
        -------
        kwant.Builder
            Builder with the full 2D translational symmetry of the lattice,
            suitable as a template for ``Builder.fill``.

        Raises
        ------
        ValueError
            If ``region`` is not one of the two supported values.  The check
            happens before any instance attribute is modified.

        Notes
        -----
        Every call overwrites ``self.lat``, ``self.a1``, ``self.b1``,
        ``self.a2``, ``self.b2`` and ``self.bulk_sym``.
        """
        if region == 'scattering':
            onsite_1 = functions.onsite_1
            onsite_2 = functions.onsite_2
        elif region == 'lead':
            onsite_1 = functions.onsite_leads_1
            onsite_2 = functions.onsite_leads_2
        else:
            raise ValueError(
                f"region must be 'scattering' or 'lead', got {region!r}"
            )

        # Lattice constant and hopping scale.  gamma1 is divided by t, so it
        # is expressed in units of t (presumably both given in eV -- confirm).
        a, t = 1, 3.16
        gamma1 = 0.38 / t

        self.lat = kwant.lattice.general(
            [(a * np.sqrt(3) / 2, a * 1 / 2), (0, a * 1)],
            [
                # Lattice A, layer 1
                (0, 0),
                # Lattice B, layer 1
                (a * 1 / (2 * np.sqrt(3)), a * 1 / 2),
                # Lattice A, layer 2
                (-a * 1 / (2 * np.sqrt(3)), a * 1 / 2),
                # Lattice B, layer 2 (same in-plane position as A, layer 1)
                (0, 0),
            ],
            norbs=1,
        )
        # Sublattices come back in the order of the basis list above.
        self.a1, self.b1, self.a2, self.b2 = self.lat.sublattices

        self.bulk_sym = kwant.TranslationalSymmetry(
            self.lat.vec((1, 0)),
            self.lat.vec((0, 1)),
        )

        bulk = kwant.Builder(self.bulk_sym)

        # Onsite terms: one function per layer.
        bulk[self.a1(0, 0)] = onsite_1
        bulk[self.b1(0, 0)] = onsite_1
        bulk[self.a2(0, 0)] = onsite_2
        bulk[self.b2(0, 0)] = onsite_2

        # Hopping specifications as (lattice displacement, family, family),
        # passed unchanged to kwant.builder.HoppingKind.
        # In-layer A-B bonds, layer 1.
        hoppings1 = (
            ((0, 0), self.a1, self.b1),
            ((0, 1), self.a1, self.b1),
            ((1, 0), self.a1, self.b1),
        )
        # In-layer A-B bonds, layer 2.
        hoppings2 = (
            ((0, 0), self.a2, self.b2),
            ((0, -1), self.a2, self.b2),
            ((1, -1), self.a2, self.b2),
        )
        # Inter-layer a2-b1 bonds.
        hoppings3 = (
            ((0, 0), self.a2, self.b1),
            ((1, 0), self.a2, self.b1),
            ((1, -1), self.a2, self.b1),
        )
        # Inter-layer b1-b2 and a1-a2 bonds.
        hoppings4 = (
            ((0, 0), self.b1, self.b2),
            ((0, -1), self.b1, self.b2),
            ((-1, 0), self.b1, self.b2),
            ((0, 0), self.a1, self.a2),
            ((0, 1), self.a1, self.a2),
            ((-1, 1), self.a1, self.a2),
        )

        def kinds(specs):
            # Expand hopping specifications into HoppingKind keys.
            return [kwant.builder.HoppingKind(*spec) for spec in specs]

        # Assign hopping values to the bulk template.
        bulk[kinds(hoppings1)] = hopping_t
        bulk[kinds(hoppings2)] = hopping_t
        # Constant coupling between the coincident a1 and b2 sites.
        bulk[kwant.builder.HoppingKind((0, 0), self.a1, self.b2)] = gamma1
        bulk[kinds(hoppings3)] = hopping_gamma3
        bulk[kinds(hoppings4)] = hopping_gamma4

        return bulk