"""
2022-2024 Sebastian de Bone (QuTech)
https://github.com/sebastiandebone/GHZ_prot_II
_____________________________________________
"""
from termcolor import colored
import matplotlib as mpl
from collections import defaultdict
import sys
sys.path.insert(1, '.')

from oopsc.threshold.sim import update_result_files as update_result_files_sc
from analysis.plot_settings import plot_fonts
import analysis.evaluate_protocols as ep
import analysis.calculate_thresholds as ct
import analysis.plot_functions as pf
import analysis.process_data as prd

# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably imported by other modules — confirm before changing or removing.
ACCURACY = 15


def _normalise_p_g_range(p_g):
    """Turn the ``p_g`` argument into a ``[lower, upper]`` pair of bounds on ``p_g`` values.

    ``None`` becomes ``[0, 1]``, a list or tuple is returned as a list, and any
    other (scalar) value ``x`` becomes ``[x, x]``.
    """
    if p_g is None:
        return [0, 1]
    if isinstance(p_g, (list, tuple)):
        # Tuples used to fall through to the scalar branch, which produced
        # [tuple, tuple] and a TypeError on the later float/tuple comparison.
        return list(p_g)
    return [p_g, p_g]


def _restrict_to_p_g_range(logical_error_rates_info, p_g_range):
    """Delete, in place, every data point whose ``p_g`` value lies outside ``p_g_range``.

    ``logical_error_rates_info`` is a nested mapping
    ``{cut_off_time: {p_g_value: ...}}``. A cut-off-time entry is removed as
    soon as its last ``p_g`` entry has been deleted.

    Returns True if at least one data point lies inside the range.
    """
    any_point_kept = False
    for cut_off_time in list(logical_error_rates_info.keys()):
        per_p_g = logical_error_rates_info[cut_off_time]
        for p_g_value in list(per_p_g.keys()):
            if p_g_value < p_g_range[0] or p_g_value > p_g_range[1]:
                del per_p_g[p_g_value]
                if not per_p_g:
                    del logical_error_rates_info[cut_off_time]
            else:
                any_point_kept = True
    return any_point_kept


def funtion_that_generates_all_figures(
        plot_type=None, set_name="Set3c", set_for_which_prots_are_found="sIIIc",
        explore_new_protocols=True, update_csv_files=True, update_protocol_information=True, review_old_results=False,
        plot_cot_dependence_protocol="simv2_sVp_4_14_1",
        n=4, max_prots_per_n=30, max_prots_per_n_k=30, p_g=None,
        target_ghz_success_rate=None,       # 92 95 97.5 99 99.5 99.75 99.9
        plot_cot_filtered_out_cots=None,
        plot_style="paper", interactive_plot=False, show_fits=False, output_fits="", zoomed_in_fit=True, zoom_number=3,
        only_show_best_new_protocols_in_plot=True, font_size_big_start=16, font_size_small_start=14,
        figure_scale_factor=1,
        number_scaling_tries=3, best_prots_found=None,
        include_so_stats_dec=False, include_so_stats_stab_fids=False, include_so_stats_weighted_sum=False,
        output_threshold_success_rates=False, nDD=None, decoder="uf",
        folder="results/superoperator/",
        folder_known_prots="Thresholds_known_prots", add_cut_off_to_location=False,
        use_pre_calculated_data=False,
):
    """Calculate thresholds for the configured protocol sets and draw the requested figure.

    For every configuration produced by ``prd.yield_settings`` this iterates over
    the protocols from ``prd.yield_protocols``. For each protocol it:

    * reads the result file with ``ct.calculate_threshold_information``;
    * keeps only the data points whose ``p_g`` lies in the requested range;
    * adds them to a per-set dataframe;
    * calls ``prd.perform_and_store_fits``.

    The results are then printed and saved with
    ``prd.print_and_save_calculated_thresholds``. Finally the figure selected by
    ``plot_type`` is drawn.

    Selected parameters. The remaining ones are forwarded to the ``prd``/``pf``
    helpers through ``plot_settings`` and ``dataframe_settings``.

    plot_type:
        "link_efficiency", "bell_succ_prob" or "bell_quality" are drawn with
        ``pf.plot_main_function``. "cot_dependence_one_figure" and
        "cot_dependence_subplots" are drawn with ``pf.plot_cut_off_time_plot``.
        Any other value only computes and prints results.
    p_g:
        ``None`` keeps everything (range [0, 1]). A single value ``x`` keeps
        only ``p_g == x``. A ``[lower, upper]`` list or tuple keeps that
        inclusive range.
    update_csv_files:
        If true, ``update_result_files_sc`` is called for each protocol's result
        file before it is read. The flag is also forwarded to
        ``prd.fill_logical_error_rates_df_dataframe``.
    use_pre_calculated_data:
        Load data with ``prd.load_pre_calculated_data``. When that data is not
        None, the per-set calculation is skipped for the five plot types above.
    interactive_plot:
        On non-Linux platforms, switches the matplotlib backend to 'TkAgg'.
    font_size_big_start, font_size_small_start, number_scaling_tries:
        Overwritten based on ``plot_style``: 16/14/3 for "paper" and
        15.06/15.06/1 otherwise. Values passed here therefore have no effect.
    target_ghz_success_rate:
        Not used by this function.
    """
    # Plot types for which pre-calculated data replaces the per-set calculation.
    pre_calculated_plot_types = ["link_efficiency", "bell_quality", "bell_succ_prob",
                                 "cot_dependence_subplots", "cot_dependence_one_figure"]
    cut_off_plot_types = ["cot_dependence_subplots", "cot_dependence_one_figure"]
    link_plot_types = ["link_efficiency", "bell_succ_prob", "bell_quality"]

    plot_settings = {
        "plot_style": plot_style, "interactive_plot": interactive_plot, "show_fits": show_fits, "output_fits": output_fits,
        "zoomed_in_fit": zoomed_in_fit, "zoom_number": zoom_number,
        "plot_type": plot_type,
        "only_show_best_new_protocols_in_plot": only_show_best_new_protocols_in_plot,
        "font_size_big_start": font_size_big_start, "font_size_small_start": font_size_small_start,
        "figure_scale_factor": figure_scale_factor, "number_scaling_tries": number_scaling_tries,
        "plot_cot_dependence_protocol": plot_cot_dependence_protocol,
        "plot_cot_filtered_out_cots": plot_cot_filtered_out_cots, "best_prots_found": best_prots_found
    }
    dataframe_settings = {
        "include_so_stats_dec": include_so_stats_dec, "include_so_stats_stab_fids": include_so_stats_stab_fids,
        "include_so_stats_weighted_sum": include_so_stats_weighted_sum,
        "output_threshold_success_rates": output_threshold_success_rates,
        "nDD": nDD, "decoder": decoder, "folder": folder,
        "folder_known_protocols": folder_known_prots,   # "in_Set..." is automatically added to the folder name
        "add_cut_off_to_location": add_cut_off_to_location,
    }
    if plot_cot_filtered_out_cots is None:
        plot_settings['plot_cot_filtered_out_cots'] = \
            [0.0164681, 0.0382001, 0.0418481, 0.0783281, 0.0844081, 0.21783452, 0.23440386, 0.26349148]

    # The p_g window is the same for every set and protocol, so compute it once.
    p_g_range = _normalise_p_g_range(p_g)

    # NOTE(review): these overwrite the caller's font_size_big_start, font_size_small_start and
    # number_scaling_tries unconditionally — confirm whether explicit arguments should take precedence.
    plot_settings["font_size_big_start"] = 16 if plot_settings["plot_style"] == "paper" else 15.06
    plot_settings["font_size_small_start"] = 14 if plot_settings["plot_style"] == "paper" else 15.06
    plot_settings["number_scaling_tries"] = 3 if plot_settings["plot_style"] == "paper" else 1

    if plot_settings["interactive_plot"] is True and sys.platform != "linux":
        mpl.use('TkAgg')

    # Placeholders for pre-calculated data that can be plotted directly.
    if use_pre_calculated_data:
        protocols_to_be_plotted_per_set, cut_off_dict_set_plot = prd.load_pre_calculated_data(plot_type)
    else:
        protocols_to_be_plotted_per_set = None
        cut_off_dict_set_plot = None
    # Per-subplot metadata for the cut-off-time figure: (protocol, subplot number, LaTeX label, figure reference).
    cut_off_information_plot = {
        0: ('simv3_sIIIc_4_7_1', 0, '\\eta^{\\ast}_\\mathrm{link}=8\\cdot10^2', 'Fig. 7'),
        1: ('expedient', 1, '\\eta^{\\ast}_\\mathrm{link}=2\\cdot10^5', 'Fig. 7'),
        2: ('refined2', 2, '\\eta^{\\ast}_\\mathrm{link}=2\\cdot10^5', 'Fig. 7'),
        3: ('simv2_sVp_4_14_1', 3, 'F_\\mathrm{link}\\approx0.828', 'Fig. 6'),
    }

    protocol_information = ep.find_protocol_information(update_protocol_information)
    # plot_settings["best_prots_found"] = ep.identify_meta_data_for_protocols(protocol_information)

    convert_name = ep.get_alternative_names()

    protocols_calculated_per_set = defaultdict(list)
    best_new_prot_per_set = {}
    cut_off_dict_set = {}
    # NOTE(review): filled below but never read; the plot call uses cut_off_information_plot instead.
    cut_off_information = {}

    plot_fonts(style=plot_settings["plot_style"])
    # NOTE(review): never incremented here, so prd.perform_and_store_fits always receives 0.
    i_output_fits = 0

    previous_set_name = None
    first_set = True
    for (explore_new_protocols_yield, set_name_yield, set_for_which_prots_are_found_yield, round_yield,
         cut_off_dep_plot_data_yield) in prd.yield_settings(
            explore_new_protocols, set_name, set_for_which_prots_are_found, plot_type=plot_type):

        # Pre-calculated data is plotted directly, so nothing needs to be recomputed.
        if plot_type in pre_calculated_plot_types and (
                protocols_to_be_plotted_per_set is not None or cut_off_dict_set_plot is not None):
            continue

        if first_set is False:
            print("\n\n\n\n\n\n")
        first_set = False
        if explore_new_protocols_yield:
            print_type_prots = f"new protocols found with set {set_for_which_prots_are_found_yield}"
        else:
            print_type_prots = "known protocols"
        print(f"{colored(f'Calculating {print_type_prots} in set {set_name_yield}:', 'green', 'on_grey', attrs=['bold'])}")

        # Only round 2 receives the first element of each best_new_prot_per_set entry collected so far.
        if round_yield == 2:
            best_new_prot_list = [prot[0] for prot in best_new_prot_per_set.values()]
        else:
            best_new_prot_list = []
        print(colored(f"best_new_prot_list: {best_new_prot_list}", "yellow"))
        print(colored(f"best_new_prot_per_set: {best_new_prot_per_set}", "yellow"))

        # Start fresh accumulators whenever the set changes; consecutive settings for one set share them.
        if set_name_yield != previous_set_name:
            best_thresholds_per_set = {}

            logical_error_rates_df, re_order_df_columns = prd.define_logical_error_rates_df(**dataframe_settings)
            thresholds_dict = defaultdict(lambda: defaultdict(dict))
            thresholds_fit_dict = defaultdict(lambda: defaultdict(dict))

        previous_set_name = set_name_yield

        folder_so, folder_rates = prd.get_folder_name(
            review_old_results, explore_new_protocols_yield, set_name_yield,
            set_for_which_prots_are_found, set_for_which_prots_are_found_yield,
            **dataframe_settings
        )

        for prot_name_short, prot_name, file_name in prd.yield_protocols(
                n, set_name_yield, explore_new_protocols_yield, set_for_which_prots_are_found_yield,
                max_prots_per_n, max_prots_per_n_k, **dataframe_settings
        ):
            rates_folder = dataframe_settings["folder"] + folder_rates
            if update_csv_files:
                update_result_files_sc(rates_folder, file_name.split(".csv")[0], print_actions=False)

            file_loc = rates_folder + file_name
            try:
                logical_error_rates_info = ct.calculate_threshold_information(
                    file_loc, return_information_per_data_point=True
                )
            except FileNotFoundError:
                continue

            if _restrict_to_p_g_range(logical_error_rates_info, p_g_range):
                # NOTE(review): passes the argument ``set_name`` rather than ``set_name_yield``, although every
                # other per-set call in this loop uses the yielded name — confirm this is intended.
                logical_error_rates_df, cdf = prd.fill_logical_error_rates_df_dataframe(
                    logical_error_rates_df, logical_error_rates_info,
                    prot_name, prot_name_short, set_name, folder_so,
                    update_csv_files, **dataframe_settings,
                )
                if cdf is False:
                    break

            thresholds_dict, thresholds_fit_dict = prd.perform_and_store_fits(
                thresholds_dict, thresholds_fit_dict, logical_error_rates_info, prot_name_short,
                logical_error_rates_df, i_output_fits, **plot_settings,
            )

        if plot_type not in link_plot_types:
            logical_error_rates_df = prd.print_logical_error_rates_df(
                logical_error_rates_df, re_order_df_columns, **dataframe_settings,
            )

        best_thresholds_per_set, best_new_prot_per_set, protocols_calculated_per_set = \
            prd.print_and_save_calculated_thresholds(
                thresholds_fit_dict, protocol_information,
                best_thresholds_per_set, plot_settings["best_prots_found"], best_new_prot_list,
                set_name_yield, explore_new_protocols_yield, round_yield,
                best_new_prot_per_set, protocols_calculated_per_set, print_results=True,
            )

        if plot_type in cut_off_plot_types:
            if cut_off_dep_plot_data_yield is not None:
                protocol = cut_off_dep_plot_data_yield[0]
                subplot_number = cut_off_dep_plot_data_yield[1]
                cut_off_dict_set[subplot_number] = thresholds_fit_dict[protocol]
                cut_off_information[subplot_number] = cut_off_dep_plot_data_yield
            else:
                protocol = plot_settings["plot_cot_dependence_protocol"]
                cut_off_dict_set[0] = thresholds_fit_dict[protocol]

    if cut_off_dict_set_plot is None:
        cut_off_dict_set_plot = cut_off_dict_set
    if plot_type in cut_off_plot_types:
        pf.plot_cut_off_time_plot(cut_off_dict_set_plot, convert_name, cut_off_information_plot, **plot_settings)

    if plot_type in link_plot_types:
        if plot_settings["only_show_best_new_protocols_in_plot"]:
            envelope_source = protocols_to_be_plotted_per_set if protocols_to_be_plotted_per_set is not None \
                else protocols_calculated_per_set
            protocols_to_be_plotted_per_set = ep.create_envelope_function_of_new_protocols(envelope_source)
        pf.plot_main_function(protocols_to_be_plotted_per_set, protocols_calculated_per_set, convert_name, **plot_settings)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
function_that_generates_all_figures = funtion_that_generates_all_figures


if __name__ == "__main__":
    # Settings for a single script run. Valid values for "plot_type" are:
    #   "link_efficiency", "bell_succ_prob", "bell_quality",
    #   "cot_dependence_one_figure", "cot_dependence_subplots"
    run_options = dict(
        plot_type="link_efficiency",
        set_name="Set3c",
        # Alternatives used before: "all_sets", "sVr", "all_sV"
        set_for_which_prots_are_found="sIIIc",
        add_cut_off_to_location=False,
        explore_new_protocols=True,
        update_csv_files=True,
        update_protocol_information=True,
        show_fits=False,
        use_pre_calculated_data=False,
    )
    funtion_that_generates_all_figures(**run_options)
