from __future__ import annotations

from typing import List, Iterable, TypedDict, Union, Tuple
from qoala.runtime.config import (
    ProcNodeConfig,
    TopologyConfig,
    LatenciesConfig,
    NtfConfig
)

def create_procnode_cfg(name: str, id: int, num_qubits: int, t1: int | None, t2: int | None) -> ProcNodeConfig:
    """Build the ``ProcNodeConfig`` for a single processing node.

    :param name: Node name stored in the config.
    :param id: Node id stored in the config.
    :param num_qubits: Number of qubits in the node topology.
    :param t1: T1 value for a uniform T1/T2 topology (units as expected by qoala).
    :param t2: T2 value for a uniform T1/T2 topology (units as expected by qoala).
    :return: Config with a fixed ``qnos_instr_time`` of 1000 and a ``GenericNtf``.

    The topology uses T1/T2 noise only when BOTH ``t1`` and ``t2`` are given;
    if either is None, a perfect topology is used instead.
    """
    if t1 is None or t2 is None:
        topology = TopologyConfig.perfect_config_uniform_default_params(num_qubits)
    else:
        topology = TopologyConfig.uniform_t1t2_qubits_perfect_gates_default_params(
            num_qubits, t1, t2
        )

    return ProcNodeConfig(
        node_name=name,
        node_id=id,
        topology=topology,
        latencies=LatenciesConfig(qnos_instr_time=1000),
        ntf=NtfConfig.from_cls_name("GenericNtf"),
    )

class QubitType:
    """Counter for a pool of qubits of one type: how many are free, capped at a total."""

    def __init__(self, available: int, total: int | None = None):
        """
        :param available: Number of qubits of this type that are currently available.
        :param total: Optional upper bound for the pool. If None, or smaller than
            ``available``, ``available`` is used as the total.
        """
        # A total below the starting availability is raised to ``available`` so that
        # release() never clamps the pool below its initial state.
        self._total: int = available if total is None else max(total, available)
        self._quantity: int = available

    @property
    def quantity_available(self) -> int:
        """Number of qubits of this type that are currently free."""
        return self._quantity

    def use(self, n: int) -> int:
        """Reserve ``n`` qubits from the pool.

        :param n: Number of qubits to reserve (must be non-negative).
        :return: Number of qubits still available after the reservation, or -1 if
            fewer than ``n`` are available (in which case nothing is reserved).
        :raises ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"Cannot use a negative number of qubits: {n}")
        if n <= self._quantity:
            self._quantity -= n
            return self._quantity
        return -1

    def release(self, n: int) -> None:
        """Return ``n`` qubits to the pool, never exceeding the configured total.

        :param n: Number of qubits to release (must be non-negative).
        :raises ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"Cannot release a negative number of qubits: {n}")
        self._quantity = min(self._total, self._quantity + n)




class Node:
    """A network node with an id, an alias, a session registry and its qoala config.

    Equality, ordering and hashing are all based on ``id`` only.
    """

    def __init__(
        self,
        node_id: int,
        node_alias: str | None = None,
        next_session_id: int = 0,
        number_of_qubits: int = 3,
        t1: int | None = None,
        t2: int | None = None,
    ):
        """
        :param node_id: Identifier of the node; used for equality, ordering and hashing.
        :param node_alias: Human-readable name. Defaults to ``str(node_id)``.
        :param next_session_id: First session id this node hands out.
        :param number_of_qubits: Number of qubits in the node's topology.
        :param t1: Passed to ``create_procnode_cfg``; a T1/T2 topology is used only
            when both ``t1`` and ``t2`` are given.
        :param t2: See ``t1``.
        """
        self._id: int = node_id

        # Fall back to the string form of the id whenever no alias is given
        # (for a str id this is the id itself).
        self._alias: str = str(node_id) if node_alias is None else node_alias

        self._next_session_id = next_session_id

        # Maps session alias -> session id assigned by register_session().
        self._registered_processes: dict[str, int] = {}

        self._cfg = create_procnode_cfg(self._alias, self._id, number_of_qubits, t1, t2)

    @property
    def config(self) -> ProcNodeConfig:
        """The processing-node configuration built for this node."""
        return self._cfg

    @property
    def next_session_id(self) -> int:
        """Return the next free session id and advance the counter.

        NOTE: reading this property has a side effect; each access consumes one id.
        The counter is shared with ``register_session``.
        """
        current = self._next_session_id
        self._next_session_id += 1
        return current

    def register_session(self, session_alias: str) -> int | None:
        """Assign the next session id to ``session_alias``.

        :return: The newly assigned id, or None if the alias is already registered.
        """
        if session_alias in self._registered_processes:
            return None
        self._registered_processes[session_alias] = self._next_session_id
        self._next_session_id += 1
        return self._registered_processes[session_alias]

    def get_session_id(self, session_alias: str) -> int:
        """Return the session id registered for ``session_alias``.

        :raises KeyError: If the alias was never registered on this node.
        """
        try:
            return self._registered_processes[session_alias]
        except KeyError:
            raise KeyError("Unknown process id for this node.") from None

    @property
    def id(self):
        """Identifier of the node."""
        return self._id

    @property
    def alias(self):
        """Human-readable name of the node (read-only)."""
        return self._alias

    def __repr__(self) -> str:
        return f"Node(id={self._id!r}, alias={self._alias!r})"

    # Comparisons return NotImplemented for foreign types so that Python applies its
    # standard fallbacks: ``==``/``!=`` become identity checks (so ``node in mixed_list``
    # works), while ordering against a non-Node still raises TypeError.
    def __eq__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.id == other.id

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.id < other.id

    def __le__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.id <= other.id

    def __gt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.id > other.id

    def __ge__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.id >= other.id

    def __hash__(self):
        # Consistent with __eq__: equal nodes share an id, hence a hash.
        return hash(self._id)





class NodeList(list):
    """
    Convenience list of ``Node`` objects with helpers to look nodes up by id or alias.

    A query of type ``str`` is matched against node aliases; a query of type ``int``
    is matched against node ids.
    """

    def __init__(self, __iterable: Iterable[Node] | None = None):
        super().__init__(() if __iterable is None else __iterable)

    @staticmethod
    def _check_query(query: str | int) -> None:
        """Raise TypeError unless ``query`` is a str (alias) or int (id)."""
        if not isinstance(query, (str, int)):
            raise TypeError(f"query should be int or str, not {type(query)}")

    @staticmethod
    def _matches(node: Node, query: str | int) -> bool:
        """True if ``node`` matches ``query``: by alias for str, by id otherwise."""
        return node.alias == query if isinstance(query, str) else node.id == query

    @property
    def node_ids(self):
        """Ids of all nodes, in list order."""
        return [node.id for node in self]

    @property
    def node_aliases(self):
        """Aliases of all nodes, in list order."""
        return [node.alias for node in self]

    def get_alias_from_id(self, node_id: int) -> str | None:
        """Alias of the first node with ``node_id``, or None if there is none."""
        return next((node.alias for node in self if node.id == node_id), None)

    def get_id_from_alias(self, node_alias: str) -> int | None:
        """Id of the first node with ``node_alias``, or None if there is none."""
        return next((node.id for node in self if node.alias == node_alias), None)

    def get_node(self, query: str | int):
        """Return the first node matching ``query``.

        :raises TypeError: If ``query`` is neither str nor int.
        :raises KeyError: If no node matches.
        """
        self._check_query(query)
        for node in self:
            if self._matches(node, query):
                return node
        raise KeyError(query)

    def get_flow(self, query: Iterable[str | int]) -> NodeList:
        """Resolve each alias/id in ``query`` to its node, preserving order.

        :raises KeyError: If any element does not match a node.
        """
        return NodeList([self.get_node(x) for x in query])

    def is_node(self, query: str | int) -> bool:
        """True if some node matches ``query`` (same matching rule as ``get_node``).

        :raises TypeError: If ``query`` is neither str nor int.
        """
        self._check_query(query)
        return any(self._matches(node, query) for node in self)

    def is_flow(self, query: Iterable[str | int]) -> bool:
        """True if every element of ``query`` matches a node."""
        return all(self.is_node(n) for n in query)


if __name__ == '__main__':
    # Small manual demo of Node / NodeList.
    n1 = Node(node_id=2, node_alias='alice')
    n2 = Node(node_id=0)
    n3 = Node(node_id=1)

    nlist = NodeList([n1, n2, n3])

    # Slicing a list subclass returns a plain list, so re-wrap it as a NodeList.
    nlist2 = NodeList(nlist[:1])

    print(nlist.node_ids, nlist.node_aliases)
    print(nlist2.node_aliases)

    # Node.alias is a read-only property (no setter), so assigning to it raises
    # AttributeError; previously this crashed the demo.
    try:
        nlist[0].alias = 'Bob'
    except AttributeError:
        print("Node.alias is read-only")

