"""
2022-2024 Sebastian de Bone (QuTech)
https://github.com/sebastiandebone/GHZ_prot_II
_____________________________________________
"""
import os
import pickle
import time
from termcolor import colored
import dill
import argparse
import sys
sys.path.insert(1, '.')

import GHZ_prot_II.da_protocols as dap
import GHZ_prot_II.simulate_protocol.protocol_recipe as pr
import GHZ_prot_II.simulate_protocol.run_auto_generated_protocol as ragp
import analysis.evaluate_protocols as ep
from analysis.plot_settings import plot_fonts
from generate_data_and_figures.search_GHZ_protocols import LoadFromFile
from utilities.files import get_full_path


def parse_arguments():
    """Parse the command-line options of the protocol evaluation script.

    Besides the regular options, ``-argument_file`` hands an opened file to the ``LoadFromFile`` action, so that
    values can be supplied from a file instead of the command line.

    Returns:
        dict: The parsed argparse namespace converted with ``vars()``; keys are the option names without the
        leading dash (e.g. ``"data_file"``, ``"k_max_eval"``).
    """
    parser = argparse.ArgumentParser(description='Evaluate protocol recipes without GHZ cycle time, and export the best ones.')

    parser.add_argument('-data_file', type=str, default=None,
                        help='Location and filename of GHZ search data structure to be used for the evaluation.')
    parser.add_argument('-set_number', type=str, default="IIIr",
                        help='A string describing the set number for which the search was carried out.')
    parser.add_argument('-n', required=False, type=int, default=4,
                        help='An integer describing the weight of the GHZ states to be extracted from the data file.')
    parser.add_argument('-k_start_eval', required=False, type=int, default=6,
                        help='An integer describing the minimum number of Bell states of protocols to be evaluated.')
    parser.add_argument('-k_max_eval', type=int, default=None,
                        help='An integer describing the maximum number of Bell states of protocols to be evaluated.')
    parser.add_argument('-eval_nmb', type=int, default=1,
                        help='An integer describing the evaluation number.')
    parser.add_argument('-p_g_values', nargs="*", required=False, type=float, default=[0.0001, 0.0005, 0.001, 0.002, 0.003],
                        help='p_g values for which evaluation has to take place.')
    parser.add_argument('-lattices', nargs="*", required=False, type=int, default=[4, 6, 8, 10],
                        help='Lattice sizes for which evaluation has to take place.')
    # The default has to be the None object: argparse runs *string* defaults through `type`, so the string
    # 'None' would be rejected as an invalid int every time -n_DD is left out.
    parser.add_argument('-n_DD', required=False, type=int, default=None,
                        help='An integer describing how many entanglement generation attempts fit in half of a '
                             'dynamical decoupling sequence.')
    parser.add_argument('-alpha', type=float, required=False, default=None,
                        help='A float describing the bright-state population in case the single-click entanglement protocol is used.')
    # Supplied values are kept as strings; the (non-string) default 1 is passed through unchanged by argparse.
    parser.add_argument('-version_number', required=False, type=str, default=1,
                        help='An integer describing the version number of the GHZ search.')
    parser.add_argument('-multithreading', required=False, action='store_true',
                        help='Specifies if multithreading should be used.')
    parser.add_argument('-max_nr_threads', type=int, required=False, default=192,
                        help='An integer describing how many threads are maximally used for multithreading.')
    parser.add_argument('-so_iters', type=int, required=False, default=320,
                        help='An integer describing how many iterations are used to calculate the superoperator.')
    parser.add_argument('-sc_iters', type=int, required=False, default=3200,
                        help='An integer describing how many iterations are used for the surface code.')
    parser.add_argument('-show_fits', required=False, action='store_true',
                        help='Specifies if fit results should be plotted.')
    parser.add_argument('-show_fits_high_threshold', required=False, action='store_true',
                        help='Specifies if fit results should be plotted for high thresholds (above '
                             'threshold_show_fits_high_threshold).')
    parser.add_argument('-threshold_show_fits_high_threshold', type=float, required=False, default=0.00068,
                        help='Specifies threshold for which fit results can be plotted.')
    parser.add_argument('-export_protocols', required=False, action='store_true',
                        help='Specifies if protocols should be exported as protocol recipes.')
    parser.add_argument('-max_export_prots_per_class', type=int, required=False, default=2,
                        help='An integer describing how many protocols per class are exported as protocol recipes. '
                             'Here, a class is a selection of binary trees that all have the same tree structure '
                             '(but have different distillation operations).')
    parser.add_argument('-threshold_export_threshold', type=float, required=False, default=0.00009,
                        help='A float describing what the minimum threshold value is for protocols to be exported as '
                             'protocol recipes.')
    parser.add_argument('-skip_threshold_calculations', required=False, action='store_true',
                        help='Specifies if threshold calculations should be skipped.')
    parser.add_argument('-exclude_negative_intervals', required=False, action='store_true',
                        help='Specifies if threshold with confidence interval in negative p_g range should be skipped.')
    parser.add_argument('-interactive_plots', required=False, action='store_true',
                        help='Specifies if plots should be interactive.')

    parser.add_argument('-argument_file',
                        help="Loads values from a file instead of the command line.",
                        type=open,
                        action=LoadFromFile)

    args_vars = vars(parser.parse_args())

    return args_vars


def find_protocol(protocol, protocol_list=None, additional_protocols=None):
    """Look up saved and in-memory protocols that match `protocol`.

    Every candidate is compared with `ep.check_if_protocols_are_identical`, once with
    ``skip_similar_prots='identical'`` and once with ``skip_similar_prots='similar'``; the exact meaning of
    these two modes is defined by that function.

    Args:
        protocol: The protocol to look for.
        protocol_list: File names inside "results/protocols/ProtocolRecipe" to compare against. If None, all
            files in that folder whose name contains "sim" are used.
        additional_protocols: Optional iterable of ``(name, protocol)`` pairs that are compared as well,
            without touching the disk.

    Returns:
        dict: ``{'identical': [...], 'similar': [...]}`` holding the names of the matching protocols per
        comparison mode; saved file names come first, followed by names from `additional_protocols`.
    """
    recipe_dir = "results/protocols/ProtocolRecipe"
    if protocol_list is None:
        protocol_list = [filename for filename in os.listdir(recipe_dir) if "sim" in filename]
    if additional_protocols is None:
        additional_protocols = []

    found = {'identical': [], 'similar': []}

    for filename_saved in protocol_list:
        # Unpickle each recipe once (instead of once per comparison mode) and close the file right away.
        with open(recipe_dir + "/" + filename_saved, "rb") as recipe_file:
            protocol_saved = pickle.load(recipe_file).protocol
        for match_type, matches in found.items():
            if ep.check_if_protocols_are_identical(protocol, protocol_saved, skip_similar_prots=match_type):
                matches.append(filename_saved)

    # In-memory protocols are appended after the saved ones, so the order within each list is unchanged.
    for add_prot_name, add_protocol in additional_protocols:
        for match_type, matches in found.items():
            if ep.check_if_protocols_are_identical(protocol, add_protocol, skip_similar_prots=match_type):
                matches.append(add_prot_name)

    return found


def evaluate_and_export_prot(
        data_file, set_number, k_start_eval, k_max_eval, eval_nmb=1,
        p_g_values=None, n_DD=0, version_number=1, n=4, alpha=None, lattices=None,
        multithreading=False, max_nr_threads=192, so_iters=320, sc_iters=3200,
        show_fits=False, show_fits_high_threshold=False, threshold_show_fits_high_threshold=0.00068,
        export_protocols=True, max_export_prots_per_class=2, threshold_export_threshold=0.00009,
        skip_threshold_calculations=True, exclude_negative_intervals=True, interactive_plots=False,
):
    """Evaluate the protocols stored in a GHZ search data file and export the best ones as protocol recipes.

    For every k in ``k_start_eval..k_max_eval`` (inclusive) at most the first 100 entries of ``data[n][k]`` are
    turned into protocols. Unless `skip_threshold_calculations` is set, a threshold is determined with
    `ragp.initial_threshold_search` for each protocol that has no identical saved protocol whose file name
    contains `set_number`. Results of that search are read from, and written back to, an evaluation pickle
    file named ``<data_file without .pkl>_eval_<eval_nmb>.pkl`` (with the path component "GHZ_prot_II" replaced
    by "GHZ_prot_evaluation" when `data_file` contains "results/GHZ_prot_II"). Afterwards the candidate
    protocols are grouped into classes of similar protocols and exported to "results/protocols/ProtocolRecipe".

    Args:
        data_file: Location of the pickled GHZ search data structure; ".pkl" is appended if missing.
        set_number: String describing the set number for which the search was carried out.
        k_start_eval: Minimum number of Bell states of protocols to be evaluated.
        k_max_eval: Maximum number of Bell states of protocols to be evaluated.
        eval_nmb: Evaluation number, used in the name of the evaluation file.
        p_g_values: p_g values for the threshold search; defaults to [0.0001, 0.0005, 0.001, 0.002, 0.003].
        n_DD: Forwarded to the threshold search (number of entanglement generation attempts that fit in half
            of a dynamical decoupling sequence).
        version_number: Version number of the GHZ search, used in the protocol names.
        n: Weight of the GHZ states to be extracted from the data file.
        alpha: Forwarded to the threshold search (bright-state population for the single-click protocol).
        lattices: Lattice sizes for the threshold search; defaults to [4, 6, 8, 10].
        multithreading, max_nr_threads, so_iters, sc_iters: Forwarded to the threshold search.
        show_fits: Whether fit results are plotted by the threshold search.
        show_fits_high_threshold: If set (and `show_fits` is off), the search is repeated with plotting enabled
            for protocols whose threshold exceeds `threshold_show_fits_high_threshold`.
        threshold_show_fits_high_threshold: Threshold above which the repeated, plotted search is triggered.
        export_protocols: If True the selected recipes are written to disk; otherwise it is only printed which
            protocols would be exported.
        max_export_prots_per_class: Maximum number of exported protocols per class of similar protocols.
        threshold_export_threshold: Minimum threshold value for a protocol to be an export candidate.
        skip_threshold_calculations: If True no thresholds are calculated and every protocol is a candidate.
        exclude_negative_intervals: If True, thresholds whose lower confidence bound is not positive are
            neither re-plotted nor exported.
        interactive_plots: Forwarded to the threshold search.

    Raises:
        TypeError: If `data_file` or `k_max_eval` is None.
    """
    # Both values default to None on the command line; fail with a readable message instead of a cryptic
    # TypeError from slicing / adding to None further down.
    if data_file is None:
        raise TypeError("data_file must be the location of a GHZ search data file, not None.")
    if k_max_eval is None:
        raise TypeError("k_max_eval must be an integer, not None.")

    if not data_file.endswith(".pkl"):
        data_file += ".pkl"

    # Check the "results/GHZ_prot_evaluation" folder if a previous evaluation is available
    if "results/GHZ_prot_II" in data_file:
        data_file_eval = "/".join("GHZ_prot_evaluation" if df_part == "GHZ_prot_II" else df_part
                                  for df_part in data_file.split("/"))
    else:
        data_file_eval = data_file
    evaluation_file_name = data_file_eval[:-4] + "_eval_" + str(eval_nmb) + ".pkl"
    if os.path.isfile(evaluation_file_name):
        with open(evaluation_file_name, "rb") as evaluation_file:
            results_dictionary_full = pickle.load(evaluation_file)
    else:
        results_dictionary_full = {}

    with open(data_file, "rb") as search_data_file:
        data = pickle.load(search_data_file)
    if p_g_values is None:
        p_g_values = [0.0001, 0.0005, 0.001, 0.002, 0.003]
    if lattices is None:
        lattices = [4, 6, 8, 10]
    k_range = range(k_start_eval, k_max_eval + 1)

    plot_fonts(style="paper")

    thresholds_found = {}
    new_protocols = []
    start_time = time.time()
    for k in k_range:
        prot_new_name_prefix = f"dyn_prot_simv{version_number}_s{set_number}_{n}_{k}_"
        try:
            len_data_n_k = len(data[n][k])
        except (IndexError, KeyError):
            # No protocols stored for this (n, k) combination, whether data is list-like or dict-like.
            continue
        for i in range(min(len_data_n_k, 100)):
            print(f"\nk = {k}, b = {i + 1}, stab_fid = {data[n][k][i].state}.")
            prot_name = prot_new_name_prefix + str(i + 1)
            protocol = dap.identify_protocol(data, n, k, i)
            protocol = dap.protocol_add_meta_data(protocol)
            # Compare against the saved recipes and against the protocols handled earlier in this run.
            sim_prots = find_protocol(protocol, additional_protocols=new_protocols)
            protocol_not_yet_used = True
            if len(sim_prots['identical']) == 0:
                print(f"Protocol {prot_name} ({data[n][k][i].state}) is not yet in the list. It has similar protocols: {sim_prots['similar']}.")
            else:
                print(f"Protocol {prot_name} ({data[n][k][i].state}) is already {len(sim_prots['identical'])} times in the list: {sim_prots['identical']}.")
                print(f"It has similar protocols: {sim_prots['similar']}.")
                # An identical protocol only counts as "used" if its name contains the current set number.
                if any(set_number in prots_sim for prots_sim in sim_prots['identical']):
                    protocol_not_yet_used = False
            new_protocols.append((prot_name, protocol))

            if skip_threshold_calculations:
                thresholds_found[prot_name] = {}
                thresholds_found[prot_name]["threshold"] = None
                protocol_recipe = pr.ProtocolRecipe(protocol)
                thresholds_found[prot_name]["protocol_recipe"] = protocol_recipe
                thresholds_found[prot_name]["similar_prots"] = sim_prots

            elif protocol_not_yet_used:
                thresholds_found[prot_name] = {}
                # Reuse results of a previous evaluation of this protocol, if there are any.
                results_dictionary = results_dictionary_full[prot_name] if \
                    results_dictionary_full is not None and prot_name in results_dictionary_full.keys() else None

                set_name = set_number if set_number[:3] == "Set" else "Set" + set_number
                run_loop = True
                show_fits_inside = show_fits
                # The search runs once, and a second time with plotting switched on if the threshold found
                # in the first pass is high enough (see the condition at the end of the loop body).
                while run_loop is True:
                    par, perr, results_dictionary = ragp.initial_threshold_search(prot_name,
                                                                                  pr.ProtocolRecipe(protocol),
                                                                                  set_name,
                                                                                  p_g_values=p_g_values,
                                                                                  lattices=lattices,
                                                                                  so_iters=so_iters, sc_iters=sc_iters,
                                                                                  results_dictionary=results_dictionary,
                                                                                  multithreading=multithreading,
                                                                                  max_nr_threads=max_nr_threads,
                                                                                  show_fits=show_fits_inside,
                                                                                  zoomed_in_fit=False,
                                                                                  alpha=alpha,
                                                                                  n_DD=n_DD,
                                                                                  zoom_number=2,
                                                                                  interactive_plot=interactive_plots)

                    if isinstance(results_dictionary_full, dict):
                        results_dictionary_full[prot_name] = results_dictionary
                    print(f"Found threshold value: {par[0], (par[0] - perr, par[0] + perr)}.")
                    thresholds_found[prot_name]["threshold"] = par[0], (par[0] - perr, par[0] + perr)
                    if show_fits_inside == show_fits:
                        # Only true in the first pass: the recipe and similar protocols are stored once.
                        protocol_recipe = pr.ProtocolRecipe(protocol)
                        thresholds_found[prot_name]["protocol_recipe"] = protocol_recipe
                        thresholds_found[prot_name]["similar_prots"] = sim_prots
                    if show_fits_inside is False and show_fits_high_threshold and par[0] > threshold_show_fits_high_threshold \
                            and (exclude_negative_intervals is False
                                 or (exclude_negative_intervals and par[0] - perr > 0)):
                        show_fits_inside = True
                    else:
                        run_loop = False

    # Return value is not used here; presumably this call makes sure the target folder exists — TODO confirm.
    get_full_path(evaluation_file_name, strip_filename=True)
    with open(evaluation_file_name, "wb") as evaluation_file:
        pickle.dump(results_dictionary_full, evaluation_file)
    print(f"Calculation time: {(time.time() - start_time)/60} minutes.")
    print("")
    if not skip_threshold_calculations:
        # Handle the protocols with the highest threshold first, so they are preferred during export.
        thresholds_found = dict(sorted(thresholds_found.items(),
                                       key=lambda item: item[1]["threshold"][0], reverse=True))
        print({name: result["threshold"] for name, result in thresholds_found.items()})
    print("")

    exported_protocols = []
    exported_protocols_identical_to_others = {}
    for prot_name in thresholds_found.keys():
        protocol_is_candidate_for_export = False
        if skip_threshold_calculations:
            protocol_is_candidate_for_export = True
        else:
            threshold_value = thresholds_found[prot_name]["threshold"][0]
            threshold_value_left_bound = thresholds_found[prot_name]["threshold"][1][0]
            if threshold_export_threshold is not None and threshold_value > threshold_export_threshold and \
                    (exclude_negative_intervals is False or
                     (exclude_negative_intervals and threshold_value_left_bound > 0)):
                protocol_is_candidate_for_export = True

        if protocol_is_candidate_for_export:
            similar_protocols_to_candidate_protocol = thresholds_found[prot_name]["similar_prots"]["similar"]

            # Check if current protocol is similar to protocols that are already exported
            # NOTE(review): a class matched through its "similar_prots" list does not stop the outer loop
            # right away, so a later class can still take over via the first check — confirm this is intended.
            prot_class = None
            for i_type, export_type in enumerate(exported_protocols):
                for prot_name_exp in export_type["protocols"]:
                    if prot_name_exp in similar_protocols_to_candidate_protocol:
                        prot_class = i_type
                        break
                if prot_class is not None:
                    break
                else:
                    for sim_prot_name in export_type["similar_prots"]:
                        if prot_name == sim_prot_name:
                            prot_class = i_type
                            break

            # We allow 'max_export_prots_per_class' similar protocols to be exported (not more):
            if prot_class is not None and \
                    len(exported_protocols[prot_class]["protocols"]) >= max_export_prots_per_class:
                export_current_protocol = False
            else:
                export_current_protocol = True

            if export_current_protocol is True:
                if prot_class is None:
                    # Store a copy, so that later updates of the class never modify the candidate's own list.
                    exported_protocols.append({"protocols": [prot_name],
                                               "similar_prots": list(similar_protocols_to_candidate_protocol)})
                else:
                    exported_class = exported_protocols[prot_class]
                    exported_class["protocols"].append(prot_name)
                    # The protocol is now a member of the class, so it no longer counts as "similar" to it.
                    if prot_name in exported_class["similar_prots"]:
                        exported_class["similar_prots"].remove(prot_name)
                    exported_class["similar_prots"] = list(set(exported_class["similar_prots"]
                                                               + similar_protocols_to_candidate_protocol))
                # Export new protocol
                if export_protocols:
                    protocol_recipe = thresholds_found[prot_name]["protocol_recipe"]
                    with open("results/protocols/ProtocolRecipe/" + prot_name, "wb") as recipe_file:
                        dill.dump(protocol_recipe, recipe_file)
                    print(colored(f"Protocol {prot_name} is exported", "yellow"))
                else:
                    print(colored(f"Protocol {prot_name} will be exported", "yellow"))
                print(f"Identical protocols: {thresholds_found[prot_name]['similar_prots']['identical']}.")
                print(f"Similar protocols: {thresholds_found[prot_name]['similar_prots']['similar']}.")
                if len(thresholds_found[prot_name]['similar_prots']['identical']) > 0:
                    exported_protocols_identical_to_others[prot_name] = thresholds_found[prot_name]['similar_prots']['identical']
            else:
                if export_protocols:
                    print(colored(f"Protocol {prot_name} is not exported", "red"))
                else:
                    print(colored(f"Protocol {prot_name} will not be exported", "red"))
                print(f"Identical protocols: {thresholds_found[prot_name]['similar_prots']['identical']}.")
                print(f"Similar protocols: {thresholds_found[prot_name]['similar_prots']['similar']}.")
            print("Updated exported protocols:")
            for export_type in exported_protocols:
                print(f"Protocols {export_type['protocols']} that have similar protocols {export_type['similar_prots']}.")

    print("")
    for prot_name in exported_protocols_identical_to_others.keys():
        print(f"Exported protocol {prot_name} has identical protocols: {exported_protocols_identical_to_others[prot_name]}.")

    print("")
    if not skip_threshold_calculations:
        for export_type in exported_protocols:
            for prot_name in export_type["protocols"]:
                print(f"({prot_name}, None, [{thresholds_found[prot_name]['threshold'][0]}])")


if __name__ == "__main__":
    # Gather all settings from the command line (or from an argument file).
    cli_options = parse_arguments()

    # Settings that evaluate_and_export_prot receives positionally, in call order.
    positional_names = ("data_file", "set_number", "k_start_eval", "k_max_eval")
    # Settings that are handed over by keyword. The "argument_file" entry of the parsed options is
    # deliberately left out, as it is not a parameter of evaluate_and_export_prot.
    keyword_names = (
        "eval_nmb",
        "p_g_values",
        "n_DD",
        "version_number",
        "n",
        "alpha",
        "lattices",
        "multithreading",
        "max_nr_threads",
        "so_iters",
        "sc_iters",
        "show_fits",
        "show_fits_high_threshold",
        "threshold_show_fits_high_threshold",
        "export_protocols",
        "max_export_prots_per_class",
        "threshold_export_threshold",
        "skip_threshold_calculations",
        "exclude_negative_intervals",
        "interactive_plots",
    )

    evaluate_and_export_prot(
        *[cli_options[name] for name in positional_names],
        **{name: cli_options[name] for name in keyword_names},
    )
