"""
2021-2024 Sébastian de Bone (QuTech)
https://github.com/sebastiandebone/ghz_prot_II/
"""
import GHZ_prot_II.da_protocols as dap
import GHZ_prot_II.operations as op
import itertools
from math import factorial
import random
from copy import deepcopy
import pandas as pd
import os
import pickle


def id_nrs_elementary_links(root):
    """
    Function that returns (it does not print) the list of elementary links that need to be created between physical
    network nodes. The tree is walked depth-first: at every node with children, the child with the larger value.n is
    visited first, and if the right child is not larger than the left one, the left child is visited first.

    Parameters
    ----------
    root : binary tree with purification and distillation operations. Every tree node must expose the attributes
           left, right, id and node_nrs, and every node with children must expose value.n on both children.

    Returns
    -------
    id_nrs : list of [id, node_nrs] pairs, one for every tree node without a left child (the elementary links), in
             the order in which they are visited. The list is empty if root is None.
    """
    node = root
    node_stack = []
    id_nrs = []
    while node or node_stack:
        while node:
            if node.left is None:
                # A tree node without a left child is treated as an elementary link.
                id_nrs.append([node.id, node.node_nrs])
                node_stack.append(node.right)
                node = node.left
            elif node.right.value.n > node.left.value.n:
                # The right branch holds the larger state: visit it first and postpone the left branch.
                node_stack.append(node.left)
                node = node.right
            else:
                node_stack.append(node.right)
                node = node.left
        # The inner loop always pushes before it ends, so the stack cannot be empty here.
        node = node_stack.pop()
    return id_nrs


def match_id_elem_n_4(id_elem):
    """
    Function that greedily pairs elementary links that do not share a physical network node, so that the two links
    of a pair can be created simultaneously (for example, a link between A and B together with a link between C and
    D). Links are handled in list order and every link is paired with the first later link that is still free and
    node-disjoint from it.

    Parameters
    ----------
    id_elem : list of binary tree network identification numbers plus the two physical network numbers where the
              concerning elementary link is created in the physical network (list produced by id_nrs_elementary_links)

    Returns
    -------
    matches : list with, for every position in id_elem, the position of its partner link, or None if the link could
              not be matched.
    """
    number_of_links = len(id_elem)
    partner_of = [None] * number_of_links
    claimed_as_partner = set()
    for first in range(number_of_links):
        if first in claimed_as_partner:
            continue
        first_nodes = id_elem[first][1]
        for second in range(first + 1, number_of_links):
            if second in claimed_as_partner:
                continue
            second_nodes = id_elem[second][1]
            if set(first_nodes).isdisjoint(second_nodes):
                partner_of[first] = second
                partner_of[second] = first
                claimed_as_partner.add(second)
                break
    return partner_of


def is_prot_symmetric(protocol):
    """
    Function that evaluates if a protocol is symmetric, i.e., if every elementary link of the protocol can be paired
    with another elementary link by match_id_elem_n_4 (so that two links can always be created at the same time).

    Parameters
    ----------
    protocol : binary tree with purification and distillation operations.

    Returns
    -------
    symmetric : Boolean
        True if no elementary link is left unmatched, False otherwise
    """
    protocol = dap.protocol_add_meta_data(protocol)
    partner_per_link = match_id_elem_n_4(id_nrs_elementary_links(protocol))
    return all(partner is not None for partner in partner_per_link)


def protocol_to_list_of_nodes_based_on_id(protocol):
    """
    Function that writes a protocol as a flat list with one entry per tree node, visited in pre-order (node first,
    then its left subtree, then its right subtree).

    Parameters
    ----------
    protocol : binary tree with purification and distillation operations

    Returns
    -------
    list_of_nodes : list with one entry per tree node. Each entry is a list with five elements: node.value of the
                    link, the id of the child on the left, the id of the child on the right, the id of its parent
                    (None for the root), and a list with network node numbers on which the link is created. Both
                    child ids are None for a node without a left child. The position in the list equals the id of
                    the link only if ids are handed out in pre-order (presumably the case -- verify against
                    dap.protocol_add_meta_data). None is returned if protocol is None.
    """
    if protocol is None:
        return None
    list_of_nodes = []
    myStack = []
    node = protocol
    parent_id = None
    while node or myStack:
        while node:
            if node.left is None:
                left_id, right_id = None, None
            else:
                left_id, right_id = node.left.id, node.right.id
            list_of_nodes.append([node.value, left_id, right_id, parent_id, node.node_nrs])
            myStack.append(node)
            parent_id = node.id
            node = node.left
        # Left branch exhausted: continue with the right child of the most recent unfinished node.
        node = myStack.pop()
        parent_id = node.id
        node = node.right
    return list_of_nodes


def list_of_node_ids_info(protocol):
    """
    Function that summarises every node of a protocol in a dictionary in which the network node numbers are put in
    ascending order. The tree is visited in pre-order (node first, then its left subtree, then its right subtree).

    Parameters
    ----------
    protocol : binary tree with purification and distillation operations

    Returns
    -------
    list_of_node_ids : list with one dictionary per tree node, holding the keys
                       "node_numbers"   : the network node numbers of the tree node in ascending order,
                       "children"       : tuple with the ids of the left and right child, or (None, None) for a tree
                                          node without a left child,
                       "operation_type" : "DISTILL" if value.p_or_f is 0, "FUSE" if it is 1 and "CREATE_LINK"
                                          otherwise,
                       "operator"       : for "DISTILL" the decimal operator value.dec, re-expressed for the sorted
                                          qubit order if node_nrs was not sorted yet; for "FUSE" a network node
                                          number that both children have in common; None for "CREATE_LINK".
                       None is returned if protocol is None.
    """
    if protocol is None:
        return None
    list_of_node_ids = []
    myStack = []
    node = protocol
    while node or myStack:
        while node:
            number_of_nodes = len(node.node_nrs)
            # Stable argsort: position p of the sorted list holds element index_sort_node_nrs[p] of node.node_nrs.
            index_sort_node_nrs = sorted(range(number_of_nodes), key=lambda i: node.node_nrs[i])
            node_information = {
                "node_numbers": [node.node_nrs[i] for i in index_sort_node_nrs],
                "children": (node.left.id, node.right.id) if node.left is not None else (None, None),
            }
            if node.value.p_or_f == 0:
                node_information["operation_type"] = "DISTILL"
                if index_sort_node_nrs != list(range(number_of_nodes)):
                    # NOTE(review): op.dec2signs presumably returns the operator as a sequence of bits -- confirm.
                    operator_in_bin = op.dec2signs(node.value.dec, number_of_nodes)
                    # Bit i (i >= 1) toggles the two neighbouring qubits i - 1 and i.
                    qubits_inv = [0] * len(operator_in_bin)
                    for i in range(1, len(operator_in_bin)):
                        if operator_in_bin[i] == 1:
                            qubits_inv[i - 1] ^= 1
                            qubits_inv[i] ^= 1

                    # Move the toggled qubits to their sorted positions and translate them back to operator bits:
                    # all bits between two consecutive toggled qubits are set. Bit 0 is copied unchanged.
                    new_qubits_inv = [qubits_inv[i] for i in index_sort_node_nrs]
                    new_operator_in_bin = [0] * len(operator_in_bin)
                    new_operator_in_bin[0] = operator_in_bin[0]
                    ongoing_chain = False
                    for i in range(len(operator_in_bin)):
                        if ongoing_chain:
                            new_operator_in_bin[i] = 1
                            if new_qubits_inv[i] == 1:
                                ongoing_chain = False
                        elif new_qubits_inv[i] == 1:
                            new_operator_in_bin[i + 1] = 1
                            ongoing_chain = True
                    # Read the bit sequence as a binary number, most significant bit first.
                    new_operator_in_dec = 0
                    for digit in new_operator_in_bin:
                        new_operator_in_dec = new_operator_in_dec * 2 + int(digit)
                    node_information["operator"] = new_operator_in_dec
                else:
                    node_information["operator"] = node.value.dec
            elif node.value.p_or_f == 1:
                node_information["operation_type"] = "FUSE"
                node_information["operator"] = list(set(node.left.node_nrs) & set(node.right.node_nrs))[0]
            else:
                node_information["operation_type"] = "CREATE_LINK"
                node_information["operator"] = None
            list_of_node_ids.append(node_information)
            myStack.append(node)
            node = node.left
        node = myStack.pop()
        node = node.right
    return list_of_node_ids


def _remap_distill_operator(operator_dec, index_sort_node_nrs):
    """
    Helper that re-expresses the decimal operator of a "DISTILL" node after its qubits are reordered, where
    index_sort_node_nrs[p] is the old position of the qubit that ends up at position p.
    """
    # NOTE(review): op.dec2signs presumably returns the operator as a sequence of bits -- confirm.
    operator_in_bin = op.dec2signs(operator_dec, len(index_sort_node_nrs))
    # Bit i (i >= 1) toggles the two neighbouring qubits i - 1 and i.
    qubits_inv = [0] * len(operator_in_bin)
    for i in range(1, len(operator_in_bin)):
        if operator_in_bin[i] == 1:
            qubits_inv[i - 1] ^= 1
            qubits_inv[i] ^= 1
    # Move the toggled qubits to their new positions and translate them back to operator bits: all bits between
    # two consecutive toggled qubits are set. Bit 0 is copied unchanged.
    new_qubits_inv = [qubits_inv[i] for i in index_sort_node_nrs]
    new_operator_in_bin = [0] * len(operator_in_bin)
    new_operator_in_bin[0] = operator_in_bin[0]
    ongoing_chain = False
    for i in range(len(operator_in_bin)):
        if ongoing_chain:
            new_operator_in_bin[i] = 1
            if new_qubits_inv[i] == 1:
                ongoing_chain = False
        elif new_qubits_inv[i] == 1:
            new_operator_in_bin[i + 1] = 1
            ongoing_chain = True
    # Read the bit sequence as a binary number, most significant bit first.
    new_operator_in_dec = 0
    for digit in new_operator_in_bin:
        new_operator_in_dec = new_operator_in_dec * 2 + int(digit)
    return new_operator_in_dec


def _relabel_node_info(info_list, new_node_nr_order):
    """
    Helper that returns a copy of info_list (as produced by list_of_node_ids_info) in which network node i is
    renamed to new_node_nr_order[i]. The node numbers of every entry are sorted again and the operators are updated
    accordingly. The input list is not modified.
    """
    relabelled_list = []
    for node in info_list:
        new_node = dict(node)
        node_numbers = [new_node_nr_order[node_nr] for node_nr in node["node_numbers"]]
        index_sort_node_nrs = sorted(range(len(node_numbers)), key=node_numbers.__getitem__)
        new_node["node_numbers"] = [node_numbers[i] for i in index_sort_node_nrs]
        if node["operation_type"] == "FUSE":
            new_node["operator"] = new_node_nr_order[node["operator"]]
        elif node["operation_type"] == "DISTILL":
            new_node["operator"] = _remap_distill_operator(node["operator"], index_sort_node_nrs)
        relabelled_list.append(new_node)
    return relabelled_list


def check_if_protocols_are_identical(protocol1, protocol_list, skip_similar_prots='identical'):
    """
    Function that checks if protocol1 is identical or similar to at least one protocol of protocol_list.

    Parameters
    ----------
    protocol1 : binary tree with purification and distillation operations
    protocol_list : binary tree, or list of binary trees, that protocol1 is compared against
    skip_similar_prots : str
        'similar': two protocols match if every tree node has the same operation type and the same child ids.
        'identical': in addition, there must be a renaming (permutation) of the network node numbers of protocol1
        that makes the node numbers and operators of all tree nodes equal as well. All n! permutations of the
        n = len(protocol1.node_nrs) network nodes are tried.

    Returns
    -------
    match_found : Boolean
        True if a matching protocol is found. False if no match is found, if protocol1 or protocol_list is None, if
        skip_similar_prots is not one of the two keywords, or if no protocol in protocol_list has the same number
        of network nodes as protocol1.
    """
    if protocol1 is None or protocol_list is None or skip_similar_prots not in ['identical', 'similar']:
        return False
    if not isinstance(protocol_list, list):
        protocol_list = [protocol_list]
    n = len(protocol1.node_nrs)
    if not any(len(protocol.node_nrs) == n for protocol in protocol_list):
        return False

    # Keep only the protocols with the same tree structure: same number of tree nodes, and the same operation type
    # and children for every tree node. Protocols that fail this test can never be identical either, and excluding
    # them here prevents indexing past the end of a shorter protocol in the comparison below.
    info_list_1 = list_of_node_ids_info(protocol1)
    candidate_info_lists = []
    for protocol in protocol_list:
        info_list_prot = list_of_node_ids_info(protocol)
        if len(info_list_prot) != len(info_list_1):
            continue
        if all(node["operation_type"] == other_node["operation_type"]
               and node["children"] == other_node["children"]
               for node, other_node in zip(info_list_1, info_list_prot)):
            candidate_info_lists.append(info_list_prot)
    if not candidate_info_lists:
        return False
    if skip_similar_prots == "similar":
        # In this mode we even do not want protocols that have the same structure as one of the protocols in the
        # list (without this early return, protocols with the same structure but different operators or node
        # numbers are still allowed).
        return True

    for new_node_nr_order in itertools.permutations(range(n)):
        new_list = _relabel_node_info(info_list_1, new_node_nr_order)
        if any(new_list == info_list_prot for info_list_prot in candidate_info_lists):
            return True

    return False


def other_versions_of_binary_tree_protocol(protocol, number_of_orderings_checked=1000):
    """
    Function that searches for orderings of the elementary links of a protocol that lead to different sets of
    simultaneously created links (as determined by match_id_elem_n_4).

    If the number of possible orderings does not exceed number_of_orderings_checked, all permutations are evaluated.
    Otherwise the original ordering is evaluated first, followed by number_of_orderings_checked - 1 random shuffles
    (drawn with the global state of the random module).

    Parameters
    ----------
    protocol : binary tree with purification and distillation operations
    number_of_orderings_checked : int
        Maximum number of orderings that is evaluated

    Returns
    -------
    saved_id_elem : list with, for every distinct pairing that is found, the first ordering of elementary links
                    (in the format produced by id_nrs_elementary_links) that produced it
    """
    protocol = dap.protocol_add_meta_data(protocol)
    id_elem = id_nrs_elementary_links(protocol)

    # Only fall back on random sampling if not all orderings can be checked.
    random_shuffling = factorial(len(id_elem)) > number_of_orderings_checked

    id_created_matched_checked = []
    saved_id_elem = []
    i_perm = 0
    # In random mode the permutation itself is ignored: the iterator then only bounds the number of iterations.
    for perm in itertools.permutations(id_elem):
        if i_perm >= number_of_orderings_checked:
            break
        if random_shuffling:
            if i_perm != 0:
                random.shuffle(id_elem)
        else:
            id_elem = list(perm)

        id_matched = match_id_elem_n_4(id_elem)

        # Collect the groups of link ids that are created together, in order of first appearance.
        id_created_matched = []
        for i_link, i_partner in enumerate(id_matched):
            if i_partner is not None:
                matched_links = sorted([id_elem[i_link][0], id_elem[i_partner][0]])
            else:
                matched_links = [id_elem[i_link][0]]
            if matched_links not in id_created_matched:
                id_created_matched.append(matched_links)

        is_new_pairing = id_created_matched not in id_created_matched_checked
        if is_new_pairing:
            id_created_matched_checked.append(id_created_matched)
            saved_id_elem.append(deepcopy(id_elem))
        # When all permutations are enumerated only new pairings count towards the limit; random draws always do.
        if is_new_pairing or random_shuffling:
            i_perm += 1

    return saved_id_elem


def find_protocol_information(update_protocol_information=True):
    """
    Function that determines, for every protocol recipe, which other protocol recipes are identical or similar to it
    (see check_if_protocols_are_identical) and how many qubits are used per network node.

    The information is either recalculated from the pickled protocol recipes and stored in
    'results/protocols/protocol_information.csv', or read back from that file.

    Parameters
    ----------
    update_protocol_information : Boolean
        If True, or if the csv file does not exist yet, the information is recalculated and the csv file is
        (over)written. If False and the csv file exists, the information is read from the csv file.

    Returns
    -------
    protocol_information : dictionary that maps every protocol name (without its "dyn_prot_" prefix) on a dictionary
                           with the keys "identical" and "similar" (lists of protocol names) and "q" (list with one
                           integer per entry of the qubit_memory attribute of the recipe)
    """
    csv_path = 'results/protocols/protocol_information.csv'
    if update_protocol_information or not os.path.isfile(csv_path):
        protocol_recipe_folder = "results/protocols/ProtocolRecipeExport"
        protocol_recipe_folder_org = "results/protocols/ProtocolRecipe"
        all_protocols_used = [f for f in os.listdir(protocol_recipe_folder) if f.startswith("dyn_prot")]

        # Load every recipe exactly once. Recipes that are missing from the folder with the original protocols are
        # skipped.
        # NOTE: unpickling can execute arbitrary code, so only recipes created by this project should be placed in
        # this folder.
        loaded_protocols = {}
        for protocol_name in all_protocols_used:
            try:
                with open(os.path.join(protocol_recipe_folder_org, protocol_name), "rb") as recipe_file:
                    prot_rec = pickle.load(recipe_file)
            except FileNotFoundError:
                continue
            loaded_protocols[protocol_name] = (prot_rec.protocol, [len(node) for node in prot_rec.qubit_memory])

        protocol_information = {}
        for protocol_name, (protocol_full, qubits_per_node) in loaded_protocols.items():
            information = {"identical": [], "similar": [], "q": qubits_per_node}
            for other_protocol_name, (other_protocol_full, _) in loaded_protocols.items():
                if other_protocol_name == protocol_name:
                    continue
                for keyword in ["identical", "similar"]:
                    if check_if_protocols_are_identical(protocol_full, other_protocol_full,
                                                        skip_similar_prots=keyword):
                        information[keyword].append(other_protocol_name[9:])
            # The first nine characters of the file name ("dyn_prot" plus one separator character) are dropped.
            protocol_information[protocol_name[9:]] = information

        protocol_information_df = pd.DataFrame(columns=["protocol", "identical", "similar", "q"])
        protocol_information_df.set_index("protocol", inplace=True)
        for protocol_name in protocol_information.keys():
            for keyword in ["identical", "similar", "q"]:
                protocol_information_df.loc[protocol_name, keyword] = " ".\
                    join(str(item) for item in protocol_information[protocol_name][keyword])
        protocol_information_df.to_csv(csv_path)

    else:
        protocol_information_df = pd.read_csv(csv_path)
        protocol_information_df.set_index("protocol", inplace=True)
        protocol_information = {}
        for protocol_name in protocol_information_df.index:
            information = {}
            for keyword in ["identical", "similar"]:
                cell = protocol_information_df.loc[protocol_name, keyword]
                # Empty lists are stored as empty fields, which are not read back as strings.
                information[keyword] = cell.split(" ") if isinstance(cell, str) else []
            q_cell = protocol_information_df.loc[protocol_name, "q"]
            if isinstance(q_cell, str):
                information["q"] = [int(item) for item in q_cell.split(" ")]
            elif pd.isna(q_cell):
                information["q"] = []
            else:
                # A column that only holds single numbers is parsed as a numeric column instead of as text.
                information["q"] = [int(q_cell)]
            protocol_information[protocol_name] = information

    return protocol_information


def create_envelope_function_of_new_protocols(protocols_to_be_plotted_per_set):
    """
    Function that reduces, for every parameter set, the newly found protocols to a single entry that holds the data
    of the best one, while keeping the previously known protocols.

    Protocols whose name starts with "sim" count as new protocols. Of these, only the one with the highest value of
    protocol[2][0] (which has to be above 0) is kept, under the name 'Best protocol GHZ optimization'. All other
    protocols are kept unchanged, except for "plain", "minimum4x_22" and "minimum4x_40", which are dropped. The set
    "Set3p" is skipped completely. For every set, the name of the set and the original names of the selected
    protocols are printed.

    Parameters
    ----------
    protocols_to_be_plotted_per_set : dictionary that maps the name of a set on a list of protocols. Each protocol
                                      is a sequence with the protocol name as element 0 and a sequence as element 2,
                                      of which the first element is used to rank the new protocols.

    Returns
    -------
    protocols_to_be_plotted_per_set_new : dictionary that maps the name of a set on the list of selected protocols:
                                          first ('Best protocol GHZ optimization', None, protocol[2]) of the best
                                          new protocol (if there is one), followed by the kept protocols in their
                                          original order
    """
    excluded_protocols = ["plain", "minimum4x_22", "minimum4x_40"]
    protocols_to_be_plotted_per_set_new = {}
    for set_name, protocols in protocols_to_be_plotted_per_set.items():
        if set_name == "Set3p":
            continue
        best_new_prot = None
        best_new_prot_fid = 0
        old_prots = []
        for protocol in protocols:
            if protocol[0][:3] == "sim":
                # Strict comparison: in case of a tie the first protocol found is kept.
                if protocol[2][0] > best_new_prot_fid:
                    best_new_prot = protocol
                    best_new_prot_fid = protocol[2][0]
            elif protocol[0] not in excluded_protocols:
                old_prots.append(protocol)

        print(set_name)
        selected_protocols = []
        if best_new_prot is not None:
            # Report the original name of the protocol that is hidden behind the generic label.
            print(best_new_prot[0])
            selected_protocols.append(('Best protocol GHZ optimization', None, best_new_prot[2]))
        for old_prot in old_prots:
            print(old_prot[0])
            selected_protocols.append(old_prot)
        print("")
        protocols_to_be_plotted_per_set_new[set_name] = selected_protocols
    return protocols_to_be_plotted_per_set_new


def identify_meta_data_for_protocols(protocol_information):
    """
    Function that lists the names of the best protocols found per parameter set (hard-coded below) and prints which
    of these best protocols are identical to each other.

    Parameters
    ----------
    protocol_information : dictionary that maps protocol names on a dictionary with at least the key "identical"
                           (a list of protocol names), as produced by find_protocol_information. It must contain an
                           entry for each of the best protocols; otherwise a KeyError is raised.

    Returns
    -------
    best_prots_found : list with the names of the best protocols of all sets, in the order of the table below
    """
    best_prots_found_per_set = {"Set3r": ["simv1_sIIIr_4_6_1", "simv1_sIIIr_4_6_13", "simv1_sIIIr_4_6_2",
                                          "simv1_sIIIr_4_6_12"],
                                "Set3d": ["simv4_sIIId_4_6_17", "simv4_sIIId_4_6_9", "simv4_sIIId_4_6_1",
                                          "simv4_sIIId_4_6_5"],
                                "Set3q": ["simv1_sIIIq_4_6_9", "simv1_sIIIq_4_6_10", "simv1_sIIIq_4_6_1"],
                                "Set3c": ["simv3_sIIIc_4_7_1", "simv3_sIIIc_4_7_6", "simv4_sIIIc_4_7_14"],
                                "Set3e": ["simv4_sIIIe_4_7_1", "simv4_sIIIe_4_7_3"],
                                "Set3f": ["simv5_sIIIf_4_11_14", "simv5_sIIIf_4_11_2"],
                                "Set5r": ["simv1_sVr_4_12_9", "simv1_sVr_4_12_1", "simv1_sVr_4_13_9"],
                                "Set5q": ["simv1_sVq_4_12_9", "simv1_sVq_4_11_1", "simv1_sVq_4_12_1"],
                                "Set5p": ["simv1_sVp_4_11_18", "simv1_sVp_4_10_2", "simv2_sVp_4_12_1",
                                          "simv2_sVp_4_14_1", "simv2_sVp_4_15_1"],
                                "Set5a": ["simv1_sVa_4_10_10", "simv1_sVa_4_11_5", "simv1_sVa_4_9_17",
                                          "simv1_sVa_4_9_16"],
                                "Set5b": ["simv1_sVb_4_9_1", "simv1_sVb_4_10_1"],
                                "Set5c": ["simv1_sVc_4_11_5", "simv1_sVc_4_11_8", "simv3_sVc_4_11_7",
                                          "simv3_sVc_4_10_4"],
                                "Set5d": ["simv1_sVd_4_11_1", "simv1_sVd_4_11_2"],
                                "Set5e": ["simv2_sVe_4_7_7", "simv2_sVe_4_7_9", "simv2_sVe_4_7_4"],
                                "Set6k": ["simv1_sVIk_4_6_9", "simv1_sVIk_4_6_1"],
                                "Set6g": ["simv1_sVIg_4_7_9"],
                                "Set6m": ["simv2_sVIm_4_7_6", "simv2_sVIm_4_7_9"],
                                "Set6h": ["simv1_sVIh_4_7_2", "simv2_sVIh_4_7_3", "simv1_sVIh_4_7_1",
                                          "simv2_sVIh_4_7_2"]}

    # Only sets whose name starts with one of these prefixes are taken into account.
    best_prots_found = [protocol_name
                        for set_name_key, protocol_names in best_prots_found_per_set.items()
                        if set_name_key[:4] in ["Set3", "Set5", "Set6"]
                        for protocol_name in protocol_names]
    print(best_prots_found)
    print("")

    for protocol_name in best_prots_found:
        for identical_protocol in protocol_information[protocol_name]["identical"]:
            if identical_protocol in best_prots_found:
                print(f"Protocol {protocol_name} and {identical_protocol} are identical.")
    print("")

    return best_prots_found


def get_alternative_names():
    """
    Function that returns the display names of a number of protocols.

    Returns
    -------
    convert_name : dictionary that maps the internal name of a protocol on its display name. The names 'refined'
                   and 'refined2' share the same display name.
    """
    name_pairs = (
        ('simv3_sIIIc_4_7_1', 'Septimum ($k=7$)'),
        ('basic', 'Basic ($k=8$)'),
        ('minimum4x_22', 'Minimum4x22 ($k=22$)'),
        ('medium', 'Medium ($k=16$)'),
        ('expedient', 'Expedient ($k=22$)'),
        ('refined', 'Refined ($k=40$)'),
        ('refined2', 'Refined ($k=40$)'),
        ('minimum4x_40', 'Minimum4x40 ($k=40$)'),
        ('stringent', 'Stringent ($k=42$)'),
        ('simv2_sVp_4_14_1', 'New protocol ($k=14$)'),
        ('simv2_sVp_4_15_1', 'New protocol ($k=15$)'),
    )
    return dict(name_pairs)


