
import os
from math import inf, ceil
from typing import Optional, Dict, Union, Tuple

import numpy as np
import yaml

from Quantum_Network_Architecture.demands import NetworkDemand, TimingConstraints
from Quantum_Network_Architecture.demands.packets import WindowedPacket
from Quantum_Network_Architecture.networks import Network
from Quantum_Network_Architecture.networks.nodes import NodeList
from Quantum_Network_Architecture.sessions import ApplicationSession
from Quantum_Network_Architecture.utils.logging import LogManager

from itertools import combinations, permutations

from qoala.lang.parse import QoalaParser
from qoala.lang.program import QoalaProgram
from qoala.runtime.task import TaskDurationEstimator
from qoala.sim.build import build_network_from_config
from qoala.sim.network import ProcNodeNetwork


def load_program(path: str) -> QoalaProgram:
    """Read and parse a Qoala program file.

    :param path: location of the ``.iqoala`` file, resolved relative to the directory
        containing this module (an absolute path is used unchanged, as per
        ``os.path.join``).
    :return: the program produced by ``QoalaParser(text).parse()``.
    """
    full_path = os.path.join(os.path.dirname(__file__), path)
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(full_path, encoding="utf-8") as file:
        text = file.read()
    return QoalaParser(text).parse()





class SessionGenerator:
    """Generate QKD / BQC application sessions for a network.

    Every session source is keyed ``"<node path>-<identifier>"`` (e.g. ``"1-4-qkd0"``) and
    has its own exponentially distributed inter-arrival times drawn from ``self._sampler``.
    """

    # Qoala programs are parsed once, when the class is defined; paths are relative to the
    # directory of this module (see ``load_program``).
    QKD_ALICE_PROGRAM_100_BITS = load_program("qoala_files/qkd_100_pairs_alice.iqoala")
    QKD_BOB_PROGRAM_100_BITS = load_program("qoala_files/qkd_100_pairs_bob.iqoala")

    BQC_CLIENT_PROGRAM_100_ROUNDS = load_program("qoala_files/bqc_100_rounds_client.iqoala")
    BQC_SERVER_PROGRAM_100_ROUNDS = load_program("qoala_files/bqc_100_rounds_server.iqoala")

    QKD_ALICE_PROGRAM_SINGLE_ROUND = load_program("qoala_files/qkd_alice.iqoala")
    QKD_BOB_PROGRAM_SINGLE_ROUND = load_program("qoala_files/qkd_bob.iqoala")

    BQC_CLIENT_PROGRAM_SINGLE_ROUND = load_program("qoala_files/bqc_1_round_client.iqoala")
    BQC_SERVER_PROGRAM_SINGLE_ROUND = load_program("qoala_files/bqc_1_round_server.iqoala")

    def __init__(self, network: Network, sessions: Dict[str,Dict[str,Dict[str,Union[str,int,float]]]], seed: Optional[int] = None, random_offset: Optional[int] = None, homogeneous_sessions: bool = False, rng: Optional[np.random.Generator] = None) -> None:
        """
        Build a generator from a nested session specification.

        sessions in the format {n1-n2-n3-...:
                                    {identifier:
                                      {application:       "QKD" or "BQC" (case-insensitive)
                                       packet rate:
                                       session rate:      arrival rate; 0 disables renewal
                                       min instances:
                                       max duration:      creation_time + max_duration = expiry_time
                                       }
                                    }
                                }

        NOTE(review): this docstring historically stated that durations are in seconds,
        but ``qkd``/``bqc`` default ``expiry_time`` to 3600e9, which looks like
        nanoseconds -- confirm the unit expected by the simulator.

        :param network: network whose end nodes and config are used for the sessions.
        :param sessions: nested specification described above.
        :param seed: seed for a fresh ``np.random.default_rng``; ignored when ``rng`` is given.
        :param random_offset: if given, each source's first arrival is additionally shifted
            by a uniform draw from ``[0, random_offset)``.
        :param homogeneous_sessions: flag exposed via the ``homogeneous_sessions`` property.
        :param rng: explicit random generator to use instead of seeding a new one.
        :raises ValueError: if any session names an application other than QKD or BQC.
        """
        self._network: Network = network
        self._qoala_network = build_network_from_config(network.network_config)

        self._sessions = sessions

        # An explicitly supplied generator takes precedence over ``seed``.
        self._sampler = np.random.default_rng(seed) if rng is None else rng

        # Flatten the nested spec into {"<path>-<identifier>": (path, spec)}.
        flat = {
            f"{path}-{identifier}": (path, spec)
            for path, by_identifier in sessions.items()
            for identifier, spec in by_identifier.items()
        }

        self._session_keys = list(flat)
        self._rates = {key: spec["session rate"] for key, (_, spec) in flat.items()}
        self._applications = {key: spec["application"].upper() for key, (_, spec) in flat.items()}
        self._durations = {key: int(spec["max duration"]) for key, (_, spec) in flat.items()}

        if not all(app in ('QKD', 'BQC') for app in self._applications.values()):
            raise ValueError("Unknown application submitted")

        self._packet_rates = {key: spec["packet rate"] for key, (_, spec) in flat.items()}
        # The path part of the key lists the node ids, e.g. "1-4" -> [1, 4].
        self._nodes = {
            key: network.end_nodes.get_flow([int(node_id) for node_id in path.split('-')])
            for key, (path, _) in flat.items()
        }
        self._min_instances = {key: spec["min instances"] for key, (_, spec) in flat.items()}

        self._next_session_times = {key: self._first_arrival_time(key, random_offset) for key in self._session_keys}
        # Incremented before each session is built, so the first identifier suffix is 0.
        self._number_of_sessions = {key: -1 for key in self._session_keys}

        self._homogeneous_sessions = homogeneous_sessions

    def _first_arrival_time(self, key: str, random_offset: Optional[int]) -> float:
        """Draw the first arrival time of session source ``key``.

        A non-zero rate draws an exponential inter-arrival time with mean ``1 / rate``,
        plus a uniform offset in ``[0, random_offset)`` when ``random_offset`` is given.
        A zero rate yields 0.
        """
        rate = self._rates[key]
        if rate == 0:
            # NOTE(review): zero-rate sources start at time 0 (so they arrive once at the
            # beginning), whereas later rescheduling uses ``inf`` for zero rates -- confirm
            # this asymmetry is intended.
            return 0
        arrival = self._sampler.exponential(1 / rate)
        if random_offset is not None:
            arrival += self._sampler.uniform(0, random_offset)
        return arrival

    def _local_routines_duration(self, node_alias: str, program: QoalaProgram) -> float:
        """Sum the estimated durations of all local routines of ``program`` on ``node_alias``."""
        estimator = TaskDurationEstimator()
        ehi = self.qoala_network.nodes[node_alias].local_ehi
        return sum(estimator.lr_duration(ehi=ehi, routine=routine) for routine in program.local_routines.values())

    def bqc(
            self,
            identifier: str,
            nodes: NodeList,
            packet_rate: int,
            min_instances: int = 150,
            expiry_time: Union[int, float] = 3600e9,
            submission_time: Union[int, float] = 0,
             ) -> ApplicationSession:
        """
        Create the session and attached demand for a BQC application.

        The first node in ``nodes`` runs the client program, the last the server program.
        The demand's ``minsep`` is the larger of the two nodes' summed local-routine durations.

        :param identifier: unique id for session/demand
        :param nodes: nodes involved (client first, server last)
        :param packet_rate: rate attached to QoS option "A"
        :param min_instances: minimum / target number of instances
        :param expiry_time: absolute expiry time of the session
        :param submission_time: time at which the demand is submitted
        :return: the new ``ApplicationSession`` with its ``demand`` set
        """
        server_minsep = self._local_routines_duration(nodes[-1].alias, self.BQC_SERVER_PROGRAM_100_ROUNDS)
        client_minsep = self._local_routines_duration(nodes[0].alias, self.BQC_CLIENT_PROGRAM_100_ROUNDS)

        _new_session = ApplicationSession(
            _id=identifier,
            _nodes=nodes,
            _app="BQC",
            _min_number_of_instances=min_instances,
            _expiry_time=expiry_time
        )

        _new_session.demand = NetworkDemand(
            identifier=identifier,
            nodes=nodes,
            app="BQC",
            QoS_options={
                "A": (
                    WindowedPacket(5_000_000,2),
                    packet_rate
                ),
                "min": (
                    WindowedPacket(5_000_000,2),
                    0
                )
            },
            timing_constraints=TimingConstraints(
                minsep=max(client_minsep, server_minsep),
                expiry=expiry_time
            ),
            target_number_of_instances=min_instances,
            submission_time=submission_time
        )

        return _new_session

    def qkd(self,
            identifier: str,
            nodes: NodeList,
            packet_rate: int,
            min_instances: int = 150,
            expiry_time: Union[int,float] = 3600e9,
            submission_time: Union[int,float] = 0,
            ) -> ApplicationSession:
        """
        Create the session and attached demand for a QKD application.

        The first node in ``nodes`` runs Alice's program, the last Bob's. The demand's
        ``minsep`` is the larger of the two nodes' summed local-routine durations.

        :param identifier: unique id for session/demand/task
        :param nodes: nodes involved (Alice first, Bob last)
        :param packet_rate: rate attached to QoS option "A"
        :param min_instances: minimum / target number of instances
        :param expiry_time: absolute expiry time of the session
        :param submission_time: time at which the demand is submitted
        :return: the new ``ApplicationSession`` with its ``demand`` set
        """
        alice_minsep = self._local_routines_duration(nodes[0].alias, self.QKD_ALICE_PROGRAM_100_BITS)
        bob_minsep = self._local_routines_duration(nodes[-1].alias, self.QKD_BOB_PROGRAM_100_BITS)

        _new_session = ApplicationSession(
            _id=identifier,
            _nodes=nodes,
            _app="QKD",
            _min_number_of_instances=min_instances,
            _expiry_time=expiry_time
        )

        _new_session.demand = NetworkDemand(
            identifier=identifier,
            nodes=nodes,
            app="QKD",
            QoS_options={
                "A": (
                    WindowedPacket(100_000,1),
                    packet_rate
                ),
            },
            timing_constraints=TimingConstraints(
                minsep=max(alice_minsep,bob_minsep),
                expiry=expiry_time,
            ),
            target_number_of_instances=min_instances,
            submission_time=submission_time
        )

        return _new_session

    @property
    def packet_rates(self):
        """Packet rate per session key."""
        return self._packet_rates

    @property
    def qoala_network(self) -> ProcNodeNetwork:
        """The Qoala network built from the network config.

        :raises AttributeError: if it was deleted and not yet rebuilt.
        """
        if self._qoala_network is None:
            raise AttributeError("No Qoala Network currently known. Please use SessionGenerator.rebuild_qoala_network() to rebuild it from known network config")
        return self._qoala_network

    @qoala_network.setter
    def qoala_network(self, value: Optional[ProcNodeNetwork] = None):
        if not isinstance(value, ProcNodeNetwork):
            raise TypeError(f"Qoala Network must be of type qoala.sim.network.ProcNodeNetwork, not {type(value)}.")
        self._qoala_network = value

    @qoala_network.deleter
    def qoala_network(self):
        self._qoala_network = None

    def rebuild_qoala_network(self):
        """Rebuild the Qoala network from the stored network's config."""
        self._qoala_network = build_network_from_config(self._network.network_config)

    @property
    def next_session_arrival_time(self):
        """Earliest scheduled arrival time across all session sources."""
        return min(self._next_session_times.values())

    @property
    def all_session_sources(self):
        """All session keys (the list is re-sorted in place by ``next_session``)."""
        return self._session_keys

    @property
    def homogeneous_sessions(self):
        return self._homogeneous_sessions

    @property
    def session_arrival_rates(self):
        """Session arrival rate per session key."""
        return self._rates

    @classmethod
    def from_yaml_file(cls, network: Network, file: str, seed: Optional[int] = None, random_offset: Optional[int] = None, rng: Optional[np.random.Generator] = None):
        """Build a generator from a YAML file holding the nested ``sessions`` mapping.

        :raises FileNotFoundError: if ``file`` is not an existing regular file.
        """
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if not os.path.isfile(file):
            raise FileNotFoundError(f"Session configuration file not found: {file}")

        with open(file, 'r', encoding='utf-8') as F:
            yaml_data = yaml.safe_load(F)

        return cls(network=network, sessions=yaml_data, seed=seed, random_offset=random_offset, rng=rng)

    @classmethod
    def homogeneous_sessions_p2p(cls, network: Network, application: str, session_rate: float, packet_rate: float, min_instances: int, max_duration: int, seed: Optional[int] = None):
        """Create identical sessions for every unordered pair of end nodes."""
        sessions = {f"{x}-{y}": {
            application: {
                "application": application,
                "session rate": session_rate,
                "packet rate": packet_rate,
                "min instances": min_instances,
                "max duration": max_duration,
            }
            } for x,y in combinations(network.end_nodes.node_ids, 2)
        }

        return cls(network=network, sessions=sessions, seed=seed, homogeneous_sessions=True)

    @classmethod
    def create_homogeneous_sessions(cls, network: Network, clients: list[int], servers: list[int], application: str, session_rate: float, packet_rate: float, min_instances: int, max_duration: int, seed: Optional[int] = None, ):
        """Create identical sessions for every (client, server) combination."""
        sessions = {f"{x}-{y}": {
            application: {
                "application": application,
                "session rate": session_rate,
                "packet rate": packet_rate,
                "min instances": min_instances,
                "max duration": max_duration,
            }
            } for x in clients for y in servers
        }
        # NOTE(review): create_homogeneous_sessions_from_yaml logs to the session-generator
        # logger; confirm the scheduler logger is intended here.
        LogManager.get_scheduler_logger().info(f"Created sessions: {list(sessions.keys())}")

        return cls(network=network, sessions=sessions, seed=seed, homogeneous_sessions=True)

    @classmethod
    def create_homogeneous_sessions_from_yaml(cls, config_file: str, network: Network, session_rate: float, packet_rate: float, seed: Optional[int] = None,rng: Optional[np.random.Generator] = None):
        """Create identical sessions for client/server pairs described in a YAML config.

        The config must provide ``application``, ``minimum instances`` and
        ``maximum demand duration``. ``clients`` defaults to all end nodes and ``servers``
        to the clients. Only pairs with ``client < server`` are created.
        """
        with open(config_file, 'r', encoding='utf-8') as F:
            config = yaml.safe_load(F)

        if 'clients' not in config:
            config['clients'] = network.end_nodes.node_ids

        if 'servers' not in config:
            config['servers'] = config['clients']

        sessions = {f"{x}-{y}": {
            config['application']: {
                "application": config['application'],
                "session rate": session_rate,
                "packet rate": packet_rate,
                "min instances": config["minimum instances"],
                "max duration": config["maximum demand duration"],
            }
        } for x in config['clients'] for y in config['servers'] if x < y
        }

        LogManager.get_session_generator_logger().info(f"Created sessions: {list(sessions.keys())}")

        return cls(network=network, sessions=sessions, seed=seed, homogeneous_sessions=True, rng=rng)

    def get_session(self, key:str, time: Union[int, float]) -> ApplicationSession:
        """Build the session for source ``key`` submitted at ``time``.

        The session expires at ``time + max duration`` of that source.

        :raises KeyError: if ``key`` is not a known session source.
        """
        if key not in self._session_keys:
            raise KeyError(f"Unknown session key {key}")

        application = self._applications[key]
        if application == 'QKD':
            factory = self.qkd
        elif application == 'BQC':
            factory = self.bqc
        else:
            raise NotImplementedError(f"Application {application} not currently supported")

        return factory(
            identifier=f"{key}-{self._number_of_sessions[key]}",
            nodes=self._nodes[key],
            packet_rate=self._packet_rates[key],
            min_instances=self._min_instances[key],
            expiry_time=(time + self._durations[key]),
            submission_time=time,
        )

    def next_session(self) -> Tuple[ApplicationSession,float]:
        """Return the earliest pending session and its arrival time.

        The chosen source's next arrival is advanced by a fresh exponential draw
        (or set to ``inf`` if its rate is zero).
        """
        # Sorting keeps ``all_session_sources`` ordered by upcoming arrival time.
        self._session_keys.sort(key=lambda k: self._next_session_times[k])

        _next_session = self._session_keys[0]
        _time_of_arrival = self._next_session_times[_next_session]

        rate = self._rates[_next_session]
        self._next_session_times[_next_session] += self._sampler.exponential(1 / rate) if rate != 0 else inf
        self._number_of_sessions[_next_session] += 1

        # ``get_session`` needs the arrival time to set submission and expiry times.
        return self.get_session(_next_session, _time_of_arrival), _time_of_arrival

    def renew_session(self, session_key: str, time: float) -> None:
        """No-op hook; subclasses may reschedule ``session_key`` relative to ``time``."""
        pass


class SingleSessionGenerator(SessionGenerator):
    """Session generator where each source has at most one pending arrival.

    Once ``next_session`` hands out a source's session, that source's next arrival is
    set to ``inf`` until ``renew_session`` is called for it.
    """

    def renew_session(self, session_key: str, time: float) -> None:
        """Schedule the next arrival of ``session_key`` one exponential draw after ``time``.

        A zero-rate source is set to ``inf`` (never arrives).

        :raises KeyError: if ``session_key`` is not a known session source.
        """
        if session_key not in self._session_keys:
            raise KeyError(f"Unknown session key {session_key}")

        rate = self._rates[session_key]
        self._next_session_times[session_key] = time + self._sampler.exponential(1 / rate) if rate != 0 else inf

    def next_session(self) -> Tuple[Optional[ApplicationSession],Optional[float]]:
        """Return the earliest pending session and its arrival time, or ``(None, None)``.

        ``(None, None)`` is returned when there are no sources or every source is waiting
        for ``renew_session`` (all arrival times are ``inf``).
        """
        if not self._session_keys:
            return None, None

        self._session_keys.sort(key=lambda k: self._next_session_times[k])

        _next_session = self._session_keys[0]
        _time_of_arrival = self._next_session_times[_next_session]

        if _time_of_arrival == inf:
            # Nothing to hand out. Do not bump the per-source counter, otherwise the next
            # identifier for this source would skip a number.
            return None, None

        self._next_session_times[_next_session] = inf
        self._number_of_sessions[_next_session] += 1

        return self.get_session(_next_session, _time_of_arrival), _time_of_arrival




































