"""
2022-2024 Sebastian de Bone (QuTech)
https://github.com/sebastiandebone/ghz_prot_II
_____________________________________________
"""
import pandas as pd
from collections import defaultdict
from copy import deepcopy
from oopsc.threshold.plot import mle_binomial
import os


def calculate_threshold_information(file_location, print_information=False, return_information_per_data_point=False,
                                    number_of_index_columns=46):
    """Determine the threshold per cut-off time from a CSV file with surface code simulation results.

    The first ``number_of_index_columns`` columns of the file are used as a MultiIndex. This index must contain
    the levels "p_g", "p_m" and "cut_off_time"; its first level is treated as the lattice size. Index entries
    are grouped into data sets that differ only in lattice size, p_g, p_m and cut-off time. If more than one
    data set is found, a warning is printed and only the first data set is used.

    Parameters
    ----------
    file_location : str
        Path to the CSV file.
    print_information : bool, optional
        If True, print the threshold and GHZ success rate(s) for every cut-off time.
    return_information_per_data_point : bool, optional
        If True, return the per-data-point classification produced by
        ``determine_threshold_info_for_each_error_prob`` instead of the thresholds.
    number_of_index_columns : int, optional
        Number of leading CSV columns that together form the index of the data (46 by default).

    Returns
    -------
    tuple or dict
        ``(thresholds, GHZ_success_rates)`` as returned by ``calculate_threshold_per_cut_off_time``. If
        ``return_information_per_data_point`` is True, the per-data-point information dictionary is returned
        instead.

    Raises
    ------
    FileExistsError
        If the file contains no data rows. This exception type is kept for backward compatibility with
        existing callers.
    """
    data = pd.read_csv(file_location, header=0, float_precision='round_trip')
    if data.empty:
        raise FileExistsError(f"File {file_location} does not contain any data.")

    data.set_index(list(data.columns[:number_of_index_columns]), inplace=True)
    index_names = data.index.names
    index_p_g = index_names.index("p_g")
    index_p_m = index_names.index("p_m")
    index_cut_off = index_names.index("cut_off_time")

    # Index positions that may vary within one data set: lattice size (position 0), p_g, p_m and cut-off time.
    varying_positions = (0, index_p_g, index_p_m, index_cut_off)
    split_sets = defaultdict(list)
    for index_name in data.index:
        index_set_label = list(index_name)
        for position in varying_positions:
            index_set_label[position] = None
        split_sets[tuple(index_set_label)].append(index_name)

    if len(split_sets) > 1:
        print("\n\n\n\nWarning! More than one data set detected in this surface code dataframe.")
        for set_label in split_sets:
            print(set_label)
        print("\n\n\n\n")
    set_used = next(iter(split_sets))

    index_by_cut_off = split_index_by_cut_off_time(split_sets[set_used], index_cut_off)
    stripped_data = strip_and_collect_data(data, index_by_cut_off, 0, index_p_g)
    threshold_info = determine_threshold_info_for_each_error_prob(stripped_data)

    if return_information_per_data_point:
        return threshold_info

    thresholds, GHZ_success_rates = calculate_threshold_per_cut_off_time(threshold_info)

    if print_information:
        print("")
        for cut_off_time in thresholds.keys():
            print(f"For cut-off time: {cut_off_time}:")
            print(f"Threshold: {thresholds[cut_off_time]}.")
            print(f"GHZ success rate(s): {GHZ_success_rates[cut_off_time]}.")
            print("")

    return thresholds, GHZ_success_rates


def split_index_by_cut_off_time(data, index_cut_off):
    """Group index tuples by the cut-off time they contain.

    Parameters
    ----------
    data : iterable of tuple
        Index tuples (e.g. entries of a pandas MultiIndex).
    index_cut_off : int
        Position of the cut-off time inside each index tuple.

    Returns
    -------
    collections.defaultdict
        Maps every cut-off time to the list of index tuples with that cut-off time, in their original order.
    """
    groups = defaultdict(list)
    for cut_off_time, index_tuple in ((row[index_cut_off], row) for row in data):
        groups[cut_off_time].append(index_tuple)
    return groups


def strip_and_collect_data(data, index_by_cut_off, index_lattice_size, index_p_g):
    """Collect the simulation counts per cut-off time, gate error probability and lattice size.

    Parameters
    ----------
    data : pandas.DataFrame
        Data indexed by the tuples in ``index_by_cut_off``; must have the columns "GHZ_success_rate", "N" and
        "success".
    index_by_cut_off : dict
        Maps each cut-off time to a list of index tuples of ``data``.
    index_lattice_size : int
        Position of the lattice size inside an index tuple.
    index_p_g : int
        Position of the gate error probability p_g inside an index tuple.

    Returns
    -------
    dict
        ``{cut_off_time: {p_g: {lattice_size: {"N": ..., "success": ..., "GHZ_success": ...}}}}``. A later
        index tuple with the same (p_g, lattice size) overwrites an earlier one.
    """
    collected = {}
    for cut_off_time, index_names in index_by_cut_off.items():
        per_p_g = {}
        for index_name in index_names:
            lattice_size = index_name[index_lattice_size]
            p_g = index_name[index_p_g]
            entry = {
                "N": data.loc[index_name, 'N'],
                "success": data.loc[index_name, 'success'],
                "GHZ_success": data.loc[index_name, 'GHZ_success_rate'],
            }
            per_p_g.setdefault(p_g, {})[lattice_size] = entry
        collected[cut_off_time] = per_p_g
    return collected


def determine_threshold_info_for_each_error_prob(stripped_data):
    """Classify every (cut-off time, p_g) data point by comparing the success rates of the lattice sizes.

    For each cut-off time and gate error probability p_g, the success rate of every lattice size L is estimated
    with ``mle_binomial(success, N)``. The estimates of consecutive lattice sizes are then compared, ordered from
    the largest to the smallest L.

    Parameters
    ----------
    stripped_data : dict
        Nested dictionary ``{cut_off_time: {p_g: {L: {"N": ..., "success": ..., "GHZ_success": ...}}}}``, as
        produced by ``strip_and_collect_data``. Each p_g dictionary is expected to contain at least one L.

    Returns
    -------
    dict
        ``{cut_off_time: {p_g: info}}``, where ``info`` has the keys:

        - "Success_rates": ``{L: mle_binomial(success, N)}``, ordered by descending estimated success rate.
        - "Below_threshold": True if the estimated success rate never increases when going from a larger L to
          the next smaller L.
        - "Might_be_below_threshold": False if, for some consecutive pair, the smaller L beats the larger L even
          after accounting for both error margins.
        - "Might_be_above_threshold": True if, for some consecutive pair, the error margins allow the larger L
          to do no better than the smaller L. Stays False when only one L is present.
        - "GHZ_success": the GHZ success rate stored for the largest L.
        - "Number_of_iterations": the number of iterations N as an int if all L share the same N, otherwise a
          dictionary ``{L: N}``.

        If the largest L has an estimated success rate below 0.1, "Below_threshold" and
        "Might_be_below_threshold" are both forced to False.
    """
    # Collect the classification for each cut-off time and gate error probability
    threshold_info = defaultdict(lambda: defaultdict(dict))
    for cut_off_time in stripped_data.keys():
        for p_g in stripped_data[cut_off_time].keys():
            data_dict = stripped_data[cut_off_time][p_g]
            # NOTE(review): mle_binomial is assumed to return (estimated success rate, error margin). Below, [0]
            # is used as the estimate and [1] as the margin -- confirm against oopsc.threshold.plot.mle_binomial.
            success_rates = {L: mle_binomial(data_dict[L]["success"], data_dict[L]["N"]) for L in
                             stripped_data[cut_off_time][p_g].keys()}

            # Order lattice sizes from largest to smallest. Each loop step then compares a larger L
            # ("prev_point") with the next smaller L ("cur_point").
            success_rates = dict(sorted(success_rates.items(), key=lambda item: item[0], reverse=True))
            distances = list(success_rates.keys())
            might_be_below_threshold = True
            below_threshold = True
            might_be_above_threshold = False
            for i_L in range(1, len(distances)):
                prev_point = success_rates[distances[i_L - 1]]
                cur_point = success_rates[distances[i_L]]
                # The smaller L is better even after accounting for both error margins.
                if prev_point[0] + prev_point[1] < cur_point[0] - cur_point[1]:
                    might_be_below_threshold = False
                # The smaller L has a higher estimated success rate than the larger L.
                if prev_point[0] < cur_point[0]:
                    below_threshold = False
                # Within the error margins, the larger L is not necessarily better than the smaller L.
                if prev_point[0] - prev_point[1] <= cur_point[0] + cur_point[1]:
                    might_be_above_threshold = True
            # success_rates is still ordered by descending L here, so the first value belongs to the largest L.
            if (below_threshold is True or might_be_below_threshold is True) \
                    and list(success_rates.values())[0][0] < 0.1:
                below_threshold = False                 # At very low success rates, we manually set these parameters
                might_be_below_threshold = False        # to "False".
            threshold_info[cut_off_time][p_g]["Success_rates"] = dict(
                sorted(success_rates.items(), key=lambda item: item[1][0], reverse=True))
            threshold_info[cut_off_time][p_g]["Below_threshold"] = below_threshold
            threshold_info[cut_off_time][p_g]["Might_be_below_threshold"] = might_be_below_threshold
            threshold_info[cut_off_time][p_g]["Might_be_above_threshold"] = might_be_above_threshold
            # distances[0] is the largest lattice size.
            threshold_info[cut_off_time][p_g]["GHZ_success"] = data_dict[distances[0]]["GHZ_success"]
            # Store a single int if every lattice size used the same number of iterations, else a per-L dict.
            iter_p_dist = {L: int(data_dict[L]["N"]) for L in threshold_info[cut_off_time][p_g]["Success_rates"].keys()}
            if list(iter_p_dist.values()).count(list(iter_p_dist.values())[0]) == len(iter_p_dist):
                threshold_info[cut_off_time][p_g]["Number_of_iterations"] = \
                    int(data_dict[list(threshold_info[cut_off_time][p_g]["Success_rates"].keys())[0]]["N"])
            else:
                threshold_info[cut_off_time][p_g]["Number_of_iterations"] = \
                    {L: int(data_dict[L]["N"]) for L in threshold_info[cut_off_time][p_g]["Success_rates"].keys()}

        # Convert the defaultdicts back to plain dicts so missing keys raise instead of silently appearing.
        threshold_info[cut_off_time] = dict(threshold_info[cut_off_time])
    threshold_info = dict(threshold_info)
    return threshold_info


def calculate_threshold_per_cut_off_time(threshold_info):
    """Estimate the threshold gate error probability p_g (with bounds) for every cut-off time.

    The p_g values 1e-12 and 1e-6 are treated as "near-zero error" data points. They only record the GHZ success
    rate (1e-12 is preferred when both are present) and are excluded from the threshold estimate.

    Parameters
    ----------
    threshold_info : dict
        ``{cut_off_time: {p_g: info}}`` as produced by ``determine_threshold_info_for_each_error_prob``. Each
        ``info`` must contain "GHZ_success", "Below_threshold", "Might_be_below_threshold" and
        "Might_be_above_threshold".

    Returns
    -------
    thresholds : dict
        ``{cut_off_time: [threshold, (lower_bound, upper_bound)]}``. Only cut-off times with at least three
        regular (non-near-zero) p_g values get an entry. The entries are defined over the regular p_g values:

        - ``threshold`` is the largest p_g flagged "Below_threshold" (0 if none).
        - ``lower_bound`` is the largest p_g not flagged "Might_be_above_threshold" (0 if none).
        - ``upper_bound`` is the smallest p_g not flagged "Might_be_below_threshold" (1 if none).
    GHZ_success_rates : dict
        ``{cut_off_time: {p_g: GHZ success rate}}``. It contains the near-zero p_g point (if present), the
        lowest regular p_g and the threshold p_g (if a threshold > 0 was found).
    """
    # Ordered by preference: the first value found is used to record the near-zero-error GHZ success rate.
    near_zero_p_g_values = (1e-12, 1e-6)
    thresholds = {}
    GHZ_success_rates = {}
    for cut_off_time, info_per_p_g in threshold_info.items():
        for p_g in near_zero_p_g_values:
            if p_g in info_per_p_g:
                GHZ_success_rates[cut_off_time] = {p_g: info_per_p_g[p_g]["GHZ_success"]}
                break

        # Only the regular p_g values take part in the threshold estimate.
        p_g_values_asc = sorted(p_g for p_g in info_per_p_g if p_g not in near_zero_p_g_values)
        if len(p_g_values_asc) < 3:
            continue    # Too few data points to say anything meaningful about the threshold.

        lowest_p_g = p_g_values_asc[0]
        GHZ_success_rates.setdefault(cut_off_time, {})[lowest_p_g] = info_per_p_g[lowest_p_g]["GHZ_success"]

        lower_bound_threshold = 0
        upper_bound_threshold = 1
        threshold = 0
        for p_g in p_g_values_asc:
            if not info_per_p_g[p_g]["Might_be_above_threshold"]:
                lower_bound_threshold = p_g
            if info_per_p_g[p_g]["Below_threshold"]:
                threshold = p_g
        # Walk downwards so the smallest qualifying p_g is the one that remains.
        for p_g in reversed(p_g_values_asc):
            if not info_per_p_g[p_g]["Might_be_below_threshold"]:
                upper_bound_threshold = p_g

        thresholds[cut_off_time] = [threshold, (lower_bound_threshold, upper_bound_threshold)]
        if threshold > 0:
            GHZ_success_rates[cut_off_time][threshold] = info_per_p_g[threshold]["GHZ_success"]
    return thresholds, GHZ_success_rates


def update_list_of_calculated_thresholds_per_set(set_name_new, prot_name_new,
                                                 file_location="results/calculated_thresholds_per_set.csv"):
    """Register that the threshold of protocol ``prot_name_new`` has been calculated for set ``set_name_new``.

    The register is a CSV file with the columns "set_name" and "protocols". The protocols of a set are stored
    as one space-separated string, so protocol names must not contain spaces. The file is created (including
    its parent directory) if it does not exist yet. It is only rewritten when a new protocol is added.

    Parameters
    ----------
    set_name_new : str
        Name of the data set.
    prot_name_new : str
        Name of the protocol to register for this set.
    file_location : str, optional
        Path of the CSV register (defaults to "results/calculated_thresholds_per_set.csv").

    Returns
    -------
    dict
        ``{set_name: [protocol names]}`` for all sets in the register, including the update.
    """
    if os.path.isfile(file_location):
        calculated_thresholds_df = pd.read_csv(file_location)
    else:
        calculated_thresholds_df = pd.DataFrame(columns=["set_name", "protocols"])
    calculated_thresholds_df.set_index("set_name", inplace=True)

    # Empty "protocols" cells are read back as NaN, which is why only string entries are split.
    calculated_thresholds = {
        set_name: protocols.split(" ") if isinstance(protocols, str) else []
        for set_name, protocols in calculated_thresholds_df["protocols"].items()
    }

    protocols = calculated_thresholds.setdefault(set_name_new, [])
    if prot_name_new in protocols:
        return calculated_thresholds    # Already registered: nothing to write.

    protocols.append(prot_name_new)
    calculated_thresholds_df.loc[set_name_new, "protocols"] = " ".join(protocols)
    directory = os.path.dirname(file_location)
    if directory:
        os.makedirs(directory, exist_ok=True)
    calculated_thresholds_df.to_csv(file_location)

    return calculated_thresholds
