from __future__ import annotations

import itertools
from typing import Iterable, Optional

from Quantum_Network_Architecture.networks.nodes import Node, NodeList
import re
from qoala.runtime.config import (
    ProcNodeNetworkConfig,
    ClassicalConnectionConfig
)

import yaml




class Link:
    """A bidirectional link between two nodes.

    The full identifier has the form ``A--B/B--A<suffix>``. The optional suffix
    (e.g. ``" (1)"``) distinguishes several links between the same pair of
    nodes. When looking for link availability, search for a link whose
    identifier contains the string ``X--Y``: both directions are listed in the
    name, so either ordering matches, and multiple links between the same
    nodes remain possible.
    """

    def __init__(self, start_node, end_node, suffix:str = "", latency: Optional[int] = None):
        """
        :param start_node: First endpoint of the link (formatted into the ids).
        :param end_node: Second endpoint of the link (formatted into the ids).
        :param suffix: Text appended to the full identifier to disambiguate parallel links.
        :param latency: Classical latency passed to the qoala connection config; None means 0.
        """
        forward_id = f"{start_node}--{end_node}"
        backward_id = f"{end_node}--{start_node}"
        base_id = f"{forward_id}/{backward_id}"

        self.full_identifier = f"{base_id}{suffix}"
        self._start_node = start_node
        self._end_node = end_node
        self._partial_ids = [forward_id, backward_id]
        self._identifier_no_suffix = base_id

        # A missing latency is treated as a zero-latency classical connection.
        effective_latency = 0 if latency is None else latency
        self._qoala_config = ClassicalConnectionConfig.from_nodes(start_node, end_node, effective_latency)

    @property
    def config(self):
        """Qoala classical connection config built for this link."""
        return self._qoala_config

    @property
    def start_node(self):
        """First endpoint given at construction."""
        return self._start_node

    @property
    def end_node(self):
        """Second endpoint given at construction."""
        return self._end_node

    @property
    def partial_ids(self):
        """Both directional ids: ``[start--end, end--start]``."""
        return self._partial_ids

    @property
    def id_without_suffix(self):
        """Identifier ``start--end/end--start`` without the disambiguating suffix."""
        return self._identifier_no_suffix

class LinkList(list):
    """A list of :class:`Link` objects with lookups keyed on partial link ids."""

    # Partial link ids look like "<start id>--<end id>" with integer node ids.
    _PARTIAL_ID_PATTERN = re.compile(r'[0-9]+--[0-9]+')

    def __init__(self, __iterable: Iterable[Link] | None = None):
        """
        :param __iterable: Links to store. None (the default) creates an empty list.
        """
        # list(None) raises TypeError, so map the None default to an empty iterable.
        super().__init__(() if __iterable is None else __iterable)

    @property
    def partial_name_dictionary(self):
        """Map every directional partial id (``A--B`` and ``B--A``) to the link's id without suffix."""
        return {p: l.id_without_suffix for l in self for p in l.partial_ids}

    @property
    def link_configs(self):
        """Qoala classical connection configs of all links, in list order."""
        return [l.config for l in self]

    def get_full_link_name(self, partial_name: str) -> str:
        """Return the suffix-free identifier of the link matching ``partial_name``.

        :param partial_name: Directional id of the form ``"<int>--<int>"``.
        :raises ValueError: If ``partial_name`` is not exactly of that form.
        :raises KeyError: If no stored link has that partial id.
        """
        # fullmatch (not match) so trailing garbage such as "1--2abc" is rejected.
        if not self._PARTIAL_ID_PATTERN.fullmatch(partial_name):
            raise ValueError("Invalid partial link format")

        return self.partial_name_dictionary[partial_name]






class Network:
    """Base description of a network: end nodes, junction nodes and the links between them.

    Subclasses are expected to implement :meth:`get_route` and
    :meth:`network_capabilities_update`.
    """

    def __init__(
            self,
            end_nodes: NodeList,
            jct_nodes: NodeList,
            links: LinkList,
            network_config: ProcNodeNetworkConfig | None = None,
            attempt_duration: int = 1,
    ):
        """
        :param end_nodes: End nodes of the network.
        :param jct_nodes: Junction nodes of the network.
        :param links: Links of the network.
        :param network_config: Qoala network configuration. When given, its ``cconns``
            are set to the configs of ``links``. May be None.
        :param attempt_duration: Length of a time slot, in ns.
        """
        self._links: LinkList = links
        self._end_nodes: NodeList = end_nodes
        self._junction_nodes: NodeList = jct_nodes
        self._network_config: ProcNodeNetworkConfig | None = network_config

        # The parameter defaults to None, so only touch the config when one is given
        # (assigning .cconns on None would raise AttributeError).
        if network_config is not None:
            network_config.cconns = links.link_configs

        self._length_of_time_slots = attempt_duration

        # TODO: add sanity check for missing nodes

    @property
    def timeslot_duration(self):
        """
        :return: Length of the time slots, in ns
        """
        return self._length_of_time_slots

    @property
    def links(self) -> LinkList:
        return self._links

    @property
    def end_nodes(self) -> NodeList:
        return self._end_nodes

    @property
    def jct_nodes(self) -> NodeList:
        return self._junction_nodes

    @property
    def link_ids(self):
        """Full identifiers of all links, in link order."""
        return [l.full_identifier for l in self._links]

    @property
    def network_config(self) -> ProcNodeNetworkConfig | None:
        return self._network_config

    # TODO: add an add_link method with end / junction node checks.

    def get_route(self, src: Node, dst: Node):
        """Return the links forming a route from ``src`` to ``dst``. Implemented by subclasses."""
        raise NotImplementedError

    def network_capabilities_update(self, src: Node, dst: NodeList | None = None):
        """Return the capabilities from ``src`` to ``dst``. Implemented by subclasses."""
        raise NotImplementedError


class HomogeneousStarNetwork(Network):
    """Star network of identical satellite end nodes around one central junction node.

    Every satellite is linked to the junction node. All satellites share the same
    qubit count, t1/t2 values and end-to-end noise/success parameters.
    """

    # Default aliases for satellite nodes, indexed by node id.
    _NODE_NAMES = (
        "alice",
        "bob",
        "charlie",
        "dave",
        "ethel",
        "frank",
        "gareth",
        "harry",
        "inigo",
        "janice",
        "kenneth",
        "louise",
        "mary",
        "norbert",
        "ophelia",
        "patsy",
        "quint",
        "ronald",
        "simone",
        "thomas",
        "viola",
        "wesley",
        "xander",
        "zelda",
    )

    def __init__(
            self,
            number_of_satellites: int,
            central_node_id: int = 255,
            number_of_qubits_per_satellite_node: int = 3,
            prob_max_mixed: float=0.1,
            attempt_success_prob=0.3,
            attempt_duration=100_000.0,
            state_delay=99_999.0,
            t1: Optional[int] = None,
            t2: Optional[int]= None,
            classical_latency: int = 0,
            scheduling_interval: int = 300,
        ):
        """
        :param number_of_satellites: Number of satellite end nodes; they get ids 0..number_of_satellites-1.
        :param central_node_id: id of the junction node, 255 by default. If it clashes with a
            satellite id, number_of_satellites is added to it.
        :param number_of_qubits_per_satellite_node: Number of qubits at each satellite node.
        :param prob_max_mixed: Parameter p of the Werner state ρ = p·I/4 + (1-p)|φ><φ| generated
            for each end-to-end link; the resulting fidelity is 1 - 3p/4.
        :param attempt_success_prob: Probability that a single end-to-end generation attempt succeeds.
        :param attempt_duration: Time in ns for each round of entanglement generation between a pair
            of end nodes; also used as the time-slot length.
        :param state_delay: Time after which the state is generated in each cycle.
        :param t1: Passed as ``t1`` to every satellite Node (presumably a relaxation time — confirm units).
        :param t2: Passed as ``t2`` to every satellite Node (presumably a dephasing time — confirm units).
        :param classical_latency: Latency of the classical connections created between every pair of
            satellites. If None, no classical connections are configured.
        :param scheduling_interval: Value exposed through :attr:`scheduling_interval`.
        """
        if central_node_id in range(number_of_satellites):
            # Shift the junction id out of the satellite id range [0, number_of_satellites).
            central_node_id += number_of_satellites

        # Aliases are only assigned when every satellite can get a name.
        use_aliases = number_of_satellites <= len(self._NODE_NAMES)

        _end_nodes = NodeList(
            Node(node_id=i,
                 node_alias=(self._NODE_NAMES[i] if use_aliases else None),
                 number_of_qubits=number_of_qubits_per_satellite_node,
                 t1=t1,
                 t2=t2)
            for i in range(number_of_satellites))

        # One link from each satellite to the central junction node.
        _links = LinkList(Link(sat, central_node_id) for sat in range(number_of_satellites))

        # Classical connections between every unordered pair of satellites.
        if classical_latency is not None:
            classical_link_configs = [
                ClassicalConnectionConfig.from_nodes(n1.id, n2.id, classical_latency)
                for n1, n2 in itertools.combinations(_end_nodes, 2)
            ]
        else:
            classical_link_configs = []

        super().__init__(end_nodes=_end_nodes,
                         jct_nodes=NodeList([Node(node_id=central_node_id)]),
                         links=_links,
                         network_config=ProcNodeNetworkConfig.from_nodes_depolarising_noise(
                             nodes=[n.config for n in _end_nodes],
                             prob_max_mixed=prob_max_mixed,
                             attempt_duration=attempt_duration,
                             attempt_success_prob=attempt_success_prob,
                             state_delay=state_delay
                            ),
                         attempt_duration=attempt_duration,
                         )

        # Network.__init__ set cconns from the satellite--junction links; replace them
        # with the satellite--satellite classical connections built above.
        self.network_config.cconns = classical_link_configs

        # Successes per ns of attempt time.
        self._end_to_end_rate = attempt_success_prob/attempt_duration
        # Fidelity of ρ = p·I/4 + (1-p)|φ><φ| with respect to |φ>.
        self._fidelity = 1 - 3/4 * prob_max_mixed

        self._central_jct_id = central_node_id

        self._end_to_end_prob = attempt_success_prob

        self._scheduling_interval: int = scheduling_interval

    @classmethod
    def from_yaml(cls, yaml_file) -> HomogeneousStarNetwork:
        """Build a network from a YAML file whose top-level mapping holds the constructor arguments.

        :param yaml_file: Path of the YAML file.
        :raises TypeError: If the file is empty or its top level is not a mapping.
        """
        # YAML streams are Unicode; don't depend on the platform's default encoding.
        with open(yaml_file, 'r', encoding='utf-8') as F:
            params = yaml.safe_load(F)

        # safe_load returns None for an empty document; fail with a clear message.
        if not isinstance(params, dict):
            raise TypeError(
                f"Expected a mapping of constructor arguments in {yaml_file!r}, "
                f"got {type(params).__name__}"
            )

        return cls(**params)

    @property
    def scheduling_interval(self):
        return self._scheduling_interval

    @property
    def end_to_end_attempt_success_probability(self):
        return self._end_to_end_prob

    def network_capabilities_update(self, src:Node, dst:Node | NodeList | None = None):
        """Return ``{node id: {'A': (fidelity, end-to-end rate)}}`` for every destination other than ``src``.

        :param src: Source node; excluded from the result.
        :param dst: A single node, a NodeList, or None for all end nodes.
        """
        if dst is None:
            dst = self.end_nodes
        elif isinstance(dst, Node):
            dst = NodeList([dst])

        # TODO: Check links exist
        return {n.id: {'A': (self._fidelity, self._end_to_end_rate)} for n in dst if n != src}

    def get_route(self, src:Node, dst:Node):
        """Return the suffix-free ids of the src--junction and dst--junction links, in that order.

        :raises KeyError: If either node has no link to the central junction.
        """
        partial_names = self.links.partial_name_dictionary
        return [
            partial_names[f"{src.id}--{self._central_jct_id}"],
            partial_names[f"{dst.id}--{self._central_jct_id}"],
        ]





