from __future__ import annotations
from math import inf, ceil, floor

from Quantum_Network_Architecture.demands.packets import EntanglementPacket, WindowedPacket
from Quantum_Network_Architecture.demands.conversion import find_min_instances_normal_approximation, calculate_length_of_pga_window, find_min_instances_hoeffding

from Quantum_Network_Architecture.tasks import PacketGenerationTask
from Quantum_Network_Architecture.networks import Network, HomogeneousStarNetwork
from Quantum_Network_Architecture.exceptions import ExpiryError

from typing import Dict, Tuple, Optional
from dataclasses import dataclass
from Quantum_Network_Architecture.networks.nodes import NodeList, Node
import numpy as np

from rich.table import Table
from datetime import timedelta


@dataclass
class TimingConstraints:
    """Timing constraints attached to a NetworkDemand.

    Attributes:
        minsep: minimum separation. NetworkDemand converts it to timeslots with
            ``ceil(minsep / network.timeslot_duration)``, so it is presumably in the
            same unit as ``timeslot_duration`` (ns) — TODO confirm.
        expiry: absolute expiry time. NetworkDemand compares it with the current
            time and converts it to timeslots via ``expiry * 1e9 / timeslot_duration``,
            i.e. treats it as seconds. The default ``inf`` means the demand never expires.
    """

    minsep: int = 0
    expiry: float = inf  # float, not int: the default math.inf is a float


class NetworkDemand:
    """A request by an application for entanglement between end nodes.

    A demand holds its acceptable QoS options (``{QoS_id: (packet, rate)}``), its
    timing constraints and the number of instances it needs. It can build a
    PacketGenerationTask (PGT) for a given network with
    :meth:`create_packet_generation_task`. The resulting task is kept on the
    demand for bookkeeping.

    Demands are ordered by absolute expiry time: an earlier expiry sorts first.
    """

    def __init__(
        self,
        identifier: str,
        nodes: NodeList,
        app: str,
        QoS_options: Dict[str, Tuple[EntanglementPacket, float]],
        timing_constraints: TimingConstraints,
        target_number_of_instances: int = 1,
        pgt: Optional[PacketGenerationTask] = None,  # For ease of bookkeeping, keeps the PGT with the associated demand and subsequently session.
        submission_time: float = 0.  # Time at which the demand was submitted in seconds
    ):
        """Create a demand.

        Args:
            identifier: name of the demand; also used as the PGT name.
            nodes: end nodes of the demand. The first two are used to look up the
                network route when a PGT is created.
            app: application / demand class that issued the demand.
            QoS_options: mapping ``{QoS_id: (packet, rate)}``.
            timing_constraints: minimum separation and absolute expiry.
            target_number_of_instances: number of successful instances required.
            pgt: an already-created packet generation task, if any.
            submission_time: time at which the demand was submitted, in seconds.
        """
        self._id = identifier
        self._application = app
        self._end_nodes = nodes
        self._qos_options = QoS_options  # {QoS_id: (packet, rate)}
        self._timing_constraints = timing_constraints  # TimingConstraints(minsep, expiry)
        self._target_number_of_instances = target_number_of_instances
        self._packet_generation_task = pgt
        self._submission_time: float = submission_time
        self._queue_exit_time: Optional[float] = None

    def create_packet_generation_task(self,
                                      network: Network,
                                      time: float | int | None = None,
                                      add_random_period_variance: bool = False,
                                      pga_success_probability: float = 0.2,
                                      qos_key: str = "A",
                                      ) -> PacketGenerationTask:
        """Build, store and return the packet generation task for this demand.

        Args:
            network: network the task will run on. Only a HomogeneousStarNetwork
                combined with a WindowedPacket QoS option is supported.
            time: current time in seconds, the same unit as
                ``timing_constraints.expiry``. It becomes the task's creation_time.
                It is required despite the ``None`` default.
            add_random_period_variance: if True, shift the task period (in
                timeslots) by a small random integer.
            pga_success_probability: success probability of a single PGA attempt,
                in (0, 1].
            qos_key: key of the QoS option in ``qos_options`` to use.

        Returns:
            The new PacketGenerationTask. It is also stored in
            ``self.packet_generation_task``.

        Raises:
            ValueError: ``pga_success_probability`` is outside (0, 1].
            TypeError: ``time`` is None.
            KeyError: ``qos_key`` is not in ``qos_options``.
            ExpiryError: ``time`` is not before the demand's expiry.
            NotImplementedError: unsupported packet/network combination.
        """
        if not 0 < pga_success_probability <= 1:
            raise ValueError(f"pga_success_probability should be in range (0,1]. ({pga_success_probability} given).")
        if time is None:
            raise TypeError("create_packet_generation_task requires the current time; got None.")

        accepted_qos = self.qos_options[qos_key]
        packet, requested_rate = accepted_qos[0], accepted_qos[1]

        expiry = self.timing_constraints.expiry
        if not time < expiry:
            raise ExpiryError

        # Number of PGA attempts needed; presumably a Hoeffding bound with failure
        # probability 1e-5 — see find_min_instances_hoeffding.
        min_trials = find_min_instances_hoeffding(self.target_number_of_instances,
                                                  pga_success_probability, 1e-5)
        # Trials per second needed to fit min_trials before the expiry.
        min_rate = min_trials / (expiry - time)

        # A non-zero requested rate is scaled up by 1/pga_success_probability.
        # A zero rate falls back to min_rate.
        per_second_rate = requested_rate / pga_success_probability if requested_rate != 0 else min_rate
        # Convert to attempts per timeslot; network.timeslot_duration is an integer number of ns.
        adjusted_rate = per_second_rate * 1e-9 * network.timeslot_duration

        if add_random_period_variance and adjusted_rate > 0:
            # NOTE(review): int() truncates toward zero, so this draws from -2..2 with 0
            # about twice as likely as any other value. Confirm whether ±3 was intended.
            random_adjustment = int(np.random.default_rng().uniform(-3, 3))
            # Shift the period (1/rate, in timeslots) so conditions change through the
            # network schedule and the dEDF scheduler can schedule more demands simultaneously.
            jittered_period = 1 / adjusted_rate + random_adjustment
            if jittered_period > 0:  # never produce a zero or negative period
                adjusted_rate = 1 / jittered_period

        if not (isinstance(packet, WindowedPacket) and isinstance(network, HomogeneousStarNetwork)):
            raise NotImplementedError(
                "Only WindowedPacket QoS options on a HomogeneousStarNetwork are supported.")

        length_of_pga = calculate_length_of_pga_window(packet.number_of_pairs,
                                                       floor(packet.window / network.timeslot_duration),
                                                       network.end_to_end_attempt_success_probability,
                                                       pga_success_probability)

        factored_minsep = ceil(self.timing_constraints.minsep / network.timeslot_duration)

        # One PGA occupies length_of_pga timeslots plus the minimum separation,
        # which caps how often it can be scheduled.
        max_rate = 1 / (length_of_pga + factored_minsep)

        # ceil(inf) raises OverflowError, so an unbounded expiry (the TimingConstraints
        # default) is passed through as inf.
        # NOTE(review): confirm PacketGenerationTask accepts expiry_time=inf.
        expiry_timeslot = inf if expiry == inf else ceil(expiry * 1e9 / network.timeslot_duration)

        self._packet_generation_task = PacketGenerationTask(
            name=self.identifier,
            end_nodes=NodeList(self.end_nodes),
            execution_time=length_of_pga,
            rate=min(adjusted_rate, max_rate),
            number_to_schedule=min_trials,
            minsep=factored_minsep,
            demand_class=self.application,
            expiry_time=expiry_timeslot,
            relative_expiry=False,
            creation_time=time,
            links=network.get_route(self.end_nodes[0], self.end_nodes[1]),
            session_ids=None,
            max_number_of_instances_in_session=self.target_number_of_instances,
            accepted_qos_option=qos_key,
        )
        return self._packet_generation_task

    @property
    def queue_exit_time(self):
        """Time at which the demand left the queue, or None if it has not been set."""
        return self._queue_exit_time

    @queue_exit_time.setter
    def queue_exit_time(self, value):
        # Real numbers (int or float, excluding bool) are stored as float.
        # Any other value (e.g. None) is silently ignored.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            self._queue_exit_time = float(value)

    @property
    def submission_time(self) -> float:
        """Time at which the demand was submitted, in seconds."""
        return self._submission_time

    @property
    def packet_generation_task(self):
        """The PacketGenerationTask associated with this demand, or None."""
        return self._packet_generation_task

    @packet_generation_task.setter
    def packet_generation_task(self, value):
        # NOTE: a wrong type raises ValueError (not TypeError).
        if not isinstance(value, PacketGenerationTask):
            raise ValueError(
                f"packet_generation_task must be a PacketGenerationTask, got {type(value).__name__}.")
        self._packet_generation_task = value

    @packet_generation_task.deleter
    def packet_generation_task(self):
        self._packet_generation_task = None

    @property
    def target_number_of_instances(self):
        return self._target_number_of_instances

    @property
    def identifier(self):
        return self._id

    @property
    def end_nodes(self):
        return self._end_nodes

    @property
    def qos_options(self):
        """Mapping ``{QoS_id: (packet, rate)}``."""
        return self._qos_options

    @property
    def timing_constraints(self):
        return self._timing_constraints

    @property
    def application(self):
        return self._application

    @property
    def latency_to_service(self) -> Optional[float]:
        """Time between submission and the PGT's creation_time, or None if there is no PGT."""
        if self.packet_generation_task is None:
            return None
        return self.packet_generation_task.creation_time - self._submission_time

    def __lt__(self, other):
        """Order demands by absolute expiry time, earliest first."""
        if not isinstance(other, NetworkDemand):
            raise ValueError(
                f"Cannot compare NetworkDemand with {type(other).__name__}.")
        return self.timing_constraints.expiry < other.timing_constraints.expiry

    # __eq__ is intentionally not overridden: equality and hashing stay
    # identity-based, so two demands with the same expiry are not equal.

