import warnings
from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

def produce_plot(
    df,
    col: str,
    legend: bool = True,
    filter_non_zero_sessions: bool = False,
    filter_non_zero_expired_sessions: bool = False,
    separate_plots: bool = False,
) -> Tuple["plt.Figure", Union["plt.Axes", np.ndarray, Tuple["plt.Axes", ...]]]:
    """Create a new figure plotting mean ± std of ``col`` against session renewal rate.

    One curve with a shaded ±1 standard-deviation band is drawn for each
    distinct value of the ``'requested packet rate'`` index level of ``df``
    (a rate of 0 is labelled ``'adaptive rate'``). The drawing of each curve is
    delegated to ``produce_plot_existing_ax`` so both entry points stay in sync.

    Args:
        df: DataFrame whose index has a ``'requested packet rate'`` level and
            which has ``'session renewal rate'`` and ``col`` columns (plus the
            session-count columns used by the optional filters).
        col: Column to aggregate and plot; also used as the y-axis label.
        legend: Whether to draw a legend on each axes that received a curve.
        filter_non_zero_sessions: Drop rows whose ``'number of sessions'`` is 0.
        filter_non_zero_expired_sessions: Drop rows whose
            ``'number of expired sessions'`` is 0; takes precedence over
            ``filter_non_zero_sessions``.
        separate_plots: If False, all curves share one axes. If True, the figure
            gets one subplot per requested packet rate, laid out in one row.

    Returns:
        ``(fig, ax)``. When ``separate_plots`` is False, ``ax`` is the single
        Axes. Otherwise it is the per-rate Axes: a NumPy array when there are
        several subplots, or a 1-tuple when there is only one.
    """
    rate_level = df.index.get_level_values('requested packet rate')
    rates = rate_level.unique()
    plot_options = dict(
        legend=legend,
        filter_non_zero_sessions=filter_non_zero_sessions,
        filter_non_zero_expired_sessions=filter_non_zero_expired_sessions,
    )

    if not separate_plots:
        fig, ax = plt.subplots()
        produce_plot_existing_ax(df, ax, col, **plot_options)
        return fig, ax

    # plt.subplots rejects a column count of 0, so always ask for at least one.
    fig, axes = plt.subplots(1, max(len(rates), 1))
    # With a single column and squeeze=True, plt.subplots returns a bare Axes;
    # wrap it so callers can always index the result.
    if not isinstance(axes, (tuple, np.ndarray)):
        axes = (axes,)

    for ax, rate in zip(axes, rates):
        # Restrict to one rate so each subplot shows exactly one curve.
        produce_plot_existing_ax(df.loc[rate_level == rate], ax, col, **plot_options)

    return fig, axes

def produce_plot_set(df, col: str, separate_plots: bool = False):
    """Build three figures for ``col``: unfiltered, non-zero sessions, non-zero expired sessions.

    Returns a flat 6-tuple ``(fig, ax, fig, ax, fig, ax)`` in that order, where
    each pair is what ``produce_plot`` returns for the corresponding filter.
    """
    filter_variants = (
        {},
        {"filter_non_zero_sessions": True},
        {"filter_non_zero_expired_sessions": True},
    )
    flat = []
    for extra_kwargs in filter_variants:
        flat.extend(produce_plot(df, col, separate_plots=separate_plots, **extra_kwargs))
    return tuple(flat)

def produce_plot_existing_ax(
    df,
    ax: "plt.Axes",
    col: str,
    legend: bool = True,
    filter_non_zero_sessions: bool = False,
    filter_non_zero_expired_sessions: bool = False,
    separate_plots: bool = False,
) -> None:
    """Plot mean ± std of ``col`` against session renewal rate onto an existing axes.

    For each distinct value of the ``'requested packet rate'`` index level of
    ``df``, rows are grouped by ``'session renewal rate'``. The mean of ``col``
    is drawn as a line, with a shaded band of ±1 standard deviation. A rate of
    0 is labelled ``'adaptive rate'``; every other rate is labelled
    ``fixed, $R = <rate>$ Hz``.

    Plotting is best-effort per rate. If one rate fails (for example because
    a required column is missing), a ``RuntimeWarning`` is emitted and the
    remaining rates are still drawn.

    Args:
        df: DataFrame whose index has a ``'requested packet rate'`` level and
            which has ``'session renewal rate'`` and ``col`` columns (plus the
            session-count columns used by the optional filters).
        ax: Axes to draw on.
        col: Column to aggregate and plot; also used as the y-axis label.
        legend: Whether to redraw the legend after each successfully drawn curve.
        filter_non_zero_sessions: Drop rows whose ``'number of sessions'`` is 0.
        filter_non_zero_expired_sessions: Drop rows whose
            ``'number of expired sessions'`` is 0; takes precedence over
            ``filter_non_zero_sessions``.
        separate_plots: Ignored. Kept so the signature mirrors ``produce_plot``.
    """
    rate_level = df.index.get_level_values('requested packet rate')

    for rate in rate_level.unique():
        try:
            subset = df.loc[rate_level == rate]
            if filter_non_zero_expired_sessions:
                subset = subset.loc[subset["number of expired sessions"] != 0]
            elif filter_non_zero_sessions:
                subset = subset.loc[subset["number of sessions"] != 0]

            stats = subset.groupby("session renewal rate")[col].agg(['mean', 'std'])

            label = f"fixed, $R = {rate}$ Hz" if rate != 0 else 'adaptive rate'
            ax.plot(stats.index, stats["mean"], marker='.', label=label)
            ax.fill_between(
                stats.index,
                list(stats['mean'] - stats['std']),
                list(stats['mean'] + stats['std']),
                alpha=0.2,
            )

            # Raw string: "\l" is not a valid escape and warns on Python 3.12+.
            ax.set_xlabel(r"Session renewal rate $\lambda$ $(s^{-1})$")
            ax.set_ylabel(f'{col}')

            if legend:
                ax.legend()
        except Exception as err:
            # Keep going for the other rates, but do not hide the failure.
            warnings.warn(
                f"produce_plot_existing_ax: skipped requested packet rate {rate}: {err!r}",
                RuntimeWarning,
                stacklevel=2,
            )