"""
2021-2024 Sébastian de Bone (QuTech)
https://github.com/sebastiandebone/ghz_prot_II/
"""
from termcolor import colored
import dill
import os
import sys
sys.path.insert(1, '.')


class TimeBlock:
    """Operations carried out by one subsystem during one time step.

    Used as the entries of the ``time_blocks[time_step][subsystem_index]`` layout of an exported recipe.

    Args:
        subsystem: Subsystem these operations belong to (``None`` until assigned).
        elementary_links: Stored as ``elem_links``. NOTE(review): presumably the number of elementary links
            generated in this block — confirm against ``simulate_protocol.protocol_recipe``.
    """

    def __init__(self, subsystem=None, elementary_links=0):
        # Assigned left to right; every instance gets its own fresh operation list.
        # distill_blockage always starts cleared (False).
        (self.subsystem, self.list_of_operations,
         self.elem_links, self.distill_blockage) = (subsystem, [], elementary_links, False)


class SubSystem:
    """A group of network nodes together with the node groups that are listed as concurrent with it.

    Args:
        nodes: The nodes of this subsystem (stored unchanged as ``self.nodes``).
        n: Total number of nodes; node labels are taken to be ``0 .. n-1``.
    """

    def __init__(self, nodes, n):
        self.nodes = nodes
        self.concurrent_subsystems = self._calculate_concurrent_subsystems(n)

    def _calculate_concurrent_subsystems(self, n):
        """Return the groups of remaining nodes, depending on the size of this subsystem.

        "Remaining" nodes are those in ``range(n)`` that are not in ``self.nodes``, in ascending order.

        - Two nodes: a single group containing all remaining nodes.
        - One node: each remaining node as its own one-element group.
        - Any other size: an empty list.
        """
        size = len(self.nodes)
        if size not in (1, 2):
            return []
        # Plain set difference: the previous symmetric difference (``^``) would also have returned any label of
        # self.nodes lying outside range(n). Sorting makes the order independent of set iteration order.
        remaining = sorted(set(range(n)) - set(self.nodes))
        if size == 2:
            return [remaining]
        return [[node] for node in remaining]


class Operation:
    """A single step of a protocol recipe, e.g. ``CREATE_LINK``, ``SWAP``, ``DISTILL`` or ``FUSE``.

    Constructor arguments are stored as attributes (some under shorter names: ``electron_qubits`` ->
    ``e_qubits``, ``memory_qubits`` -> ``m_qubits``, ``nodes_final_state`` -> ``nodes_state``).
    ``family_tree`` initialises both ``family_tree`` and ``fr_list``, and ``success_dep`` starts as
    ``[link_id]``.
    """

    def __init__(self, nodes, electron_qubits, memory_qubits, type, operator=None, subsystem=None,
                 link_id=None, children=None, nodes_final_state=None, family_tree=None, frl=None,
                 fusion_corrections=None, delay_after_sub_block=False, i_op=None, frl_id=None, state_mem=None):
        # Tuple/chained assignments run left to right, so attributes are created in a fixed order.
        self.nodes, self.e_qubits, self.m_qubits, self.type = nodes, electron_qubits, memory_qubits, type
        self.operator, self.subsystem, self.link_id = operator, subsystem, link_id
        self.children, self.nodes_state = children, nodes_final_state
        # Both names start out referring to the same object; callers may later overwrite fr_list on its own.
        self.family_tree = self.fr_list = family_tree
        self.frl, self.frl_id = frl, frl_id
        # A fresh list per instance when no corrections are supplied.
        self.fusion_corrections = fusion_corrections if fusion_corrections is not None else []
        self.success_dep = [link_id]
        self.delay_after_sub_block, self.i_op, self.state_mem = delay_after_sub_block, i_op, state_mem

    def print_operation(self):
        """Print a one-line, colour-coded description of this operation to stdout.

        DISTILL and FUSE operations additionally get a parenthesised list of their bookkeeping fields.
        """
        kind = self.type
        has_details = kind == "DISTILL" or kind == "FUSE"

        text = colored("Node(s) " + str(self.nodes), "blue") + ": (" + str(self.i_op)
        if self.link_id is not None:
            text += ", " + colored(str(self.link_id), "cyan")
        text += ") " + colored(kind, "magenta")

        if kind == "CREATE_LINK":
            text += " between qubits " + str(self.e_qubits)
        elif kind == "SWAP":
            text += " qubit " + str(self.e_qubits) + " and " + str(self.m_qubits)
        elif has_details:
            if kind == "DISTILL":
                text += " operation " + str(self.operator)
            text += " by measuring out qubits " + str(self.e_qubits) + " and keeping qubits " + str(self.m_qubits)
        text += "."

        if has_details:
            fields = []
            if self.children is not None:
                fields.append(colored("Subtree: " + str(self.family_tree), "yellow"))
            if kind == "DISTILL":
                fields.append(colored("FR ids: " + str(self.fr_list), "green"))
                fields.append("FRL level: " + str(self.frl))
                fields.append("FRL ID: " + str(self.frl_id))
                fields.append("Success dep.: " + str(self.success_dep))
                fields.append("Delay: " + str(self.delay_after_sub_block))
            fields.append("Mems state: " + str(self.state_mem))
            if kind == "FUSE":
                correction_qubits = []
                correction_ops = []
                for correction in self.fusion_corrections:
                    correction_qubits.append(correction.qubit)
                    # 1 = Z-conditioned, 2 = X-conditioned, 3 = both (bit-encoded from condition[0]/[1]).
                    correction_ops.append(int(len(correction.condition[0]) > 0)
                                          + 2 * int(len(correction.condition[1]) > 0))
                fields.append("Correction qubits: " + str(correction_qubits)
                              + ", Correction operator: " + str(correction_ops))
            text += " (" + ", ".join(fields) + ".)"
        elif self.children is not None:
            # Other operation types have no parenthesised section; the subtree is appended with a trailing ", ".
            text += colored("Subtree: " + str(self.family_tree), "yellow") + ", "

        print(text)


class FusionCorrection:
    """A Pauli correction on ``qubit`` that is conditioned on measurement results.

    ``condition[0]`` holds the results conditioning a Z correction and ``condition[1]`` those conditioning an
    X correction; it defaults to a fresh ``[[], []]`` (no conditions).
    """

    def __init__(self, qubit, condition=None):
        self.qubit = qubit
        self.condition = [[], []] if condition is None else condition

    def print_fusion_correction(self):
        """Print a colour-coded description of this correction to stdout."""
        z_results = self.condition[0]
        text = colored("CORRECT", "magenta") + " qubit " + str(self.qubit)
        if z_results:
            text += " with operator Z conditioned on result(s) " + str(z_results)
            # Either chain on to the X clause below or close the sentence here.
            text += " and" if self.condition[1] else "."
        x_results = self.condition[1]
        if x_results:
            text += " with operator X conditioned on result(s) " + str(x_results) + "."
        print(text)


# EXPORT PROTOCOL_RECIPES OBJECTS AND FUNCTIONALITIES:

class ProtocolRecipeExport:
    """Self-contained copy of the ``simulate_protocol.protocol_recipe.ProtocolRecipe`` class.

    Instances are created completely empty; ``export_protocol_recipe`` then copies the parameters of a
    ``ProtocolRecipe`` into one. If we then dump such a ``ProtocolRecipeExport`` object with the ``dill``
    module, it can be loaded anywhere without requiring the files that define the original classes.
    """

    def __init__(self):
        # Parameters of the original ProtocolRecipe that are NOT copied to ProtocolRecipeExport:
        # protocol, all_operations, link_values.

        self.n = None
        self.qubit_memory = None
        self.swap_count = None

        self.id_elem = None
        self.id_linked = None
        self.link_children = None
        self.link_parent_id = None
        self.link_nodes = None
        self.link_memory_loc = None
        self.qubit_memory_per_time_step = None
        self.id_link_structure = None

        # Per-time-step containers, appended to while exporting; each instance gets its own empty objects.
        self.time_blocks = []
        self.delayed_distillation_check = []
        self.fusion_corrections = []
        self.subsystems = {}


def _export_fusion_correction(fusion_correction):
    """Return a ``FusionCorrection`` of this module with the same ``qubit`` and ``condition`` as the input.

    The ``condition`` object itself is shared with the input, not copied.
    """
    return FusionCorrection(fusion_correction.qubit, fusion_correction.condition)


def _export_operation(operation, n):
    """Return a copy of ``operation`` rebuilt from this module's ``Operation``/``SubSystem``/``FusionCorrection``.

    Attribute values (lists, trees, ...) are shared with ``operation``; only the wrapper objects are new.
    ``n`` is the total number of nodes, needed to rebuild the ``SubSystem``.
    """
    operation_new = Operation(operation.nodes, operation.e_qubits, operation.m_qubits, operation.type)
    operation_new.operator = operation.operator
    # Operation allows subsystem=None, so only rebuild a SubSystem when one is present.
    operation_new.subsystem = (None if operation.subsystem is None
                               else SubSystem(operation.subsystem.nodes, n))
    operation_new.link_id = operation.link_id
    operation_new.children = operation.children
    operation_new.nodes_state = operation.nodes_state
    operation_new.family_tree = operation.family_tree
    operation_new.fr_list = operation.fr_list
    operation_new.frl = operation.frl
    operation_new.frl_id = operation.frl_id
    operation_new.success_dep = operation.success_dep
    operation_new.delay_after_sub_block = operation.delay_after_sub_block
    operation_new.i_op = operation.i_op
    operation_new.state_mem = operation.state_mem
    # A falsy source value (None or empty) leaves the new operation's own empty list untouched.
    for fusion_correction in operation.fusion_corrections or []:
        operation_new.fusion_corrections.append(_export_fusion_correction(fusion_correction))
    return operation_new


def _fill_time_block(target, source, n):
    """Copy the subsystem, elementary-link count and operations of ``source`` into the empty ``target``.

    ``distill_blockage`` is not copied, so ``target`` keeps its default value.
    """
    target.subsystem = SubSystem(source.subsystem.nodes, n)
    target.elem_links = source.elem_links
    target.list_of_operations.extend(_export_operation(operation, n) for operation in source.list_of_operations)


def export_protocol_recipe(protocol_recipe):
    """Copy a ``ProtocolRecipe`` into a ``ProtocolRecipeExport`` built only from classes in this file.

    Args:
        protocol_recipe: Object exposing the ``ProtocolRecipe`` attributes read below (``n``, ``swap_count``,
            ``qubit_memory``, the ``id_*``/``link_*`` fields, ``time_blocks``, ``delayed_distillation_check``,
            ``fusion_corrections`` and ``subsystems``).

    Returns:
        A new ``ProtocolRecipeExport``. Its plain attributes reference the same values as ``protocol_recipe``.
        Time blocks, operations, subsystems and fusion corrections are rebuilt with this module's classes.
    """
    n = protocol_recipe.n
    recipe_export = ProtocolRecipeExport()
    recipe_export.n = n
    recipe_export.swap_count = protocol_recipe.swap_count
    recipe_export.qubit_memory = protocol_recipe.qubit_memory

    recipe_export.id_elem = protocol_recipe.id_elem
    recipe_export.id_linked = protocol_recipe.id_linked
    recipe_export.link_children = protocol_recipe.link_children
    recipe_export.link_parent_id = protocol_recipe.link_parent_id
    recipe_export.link_nodes = protocol_recipe.link_nodes
    recipe_export.link_memory_loc = protocol_recipe.link_memory_loc
    recipe_export.qubit_memory_per_time_step = protocol_recipe.qubit_memory_per_time_step
    recipe_export.id_link_structure = protocol_recipe.id_link_structure

    for i_ts, time_step in enumerate(protocol_recipe.time_blocks):
        # Exactly two TimeBlocks per time step. NOTE(review): a time step with operations in a third
        # subsystem raises IndexError here — confirm two is always the maximum.
        blocks = [TimeBlock(), TimeBlock()]
        delayed_checks = {}
        corrections = {}
        recipe_export.time_blocks.append(blocks)
        recipe_export.delayed_distillation_check.append(delayed_checks)
        recipe_export.fusion_corrections.append(corrections)

        for i_ssys, ssys in enumerate(time_step):
            # Subsystems without operations keep an empty default TimeBlock.
            if ssys.list_of_operations:
                _fill_time_block(blocks[i_ssys], ssys, n)

        source_checks = protocol_recipe.delayed_distillation_check[i_ts]
        if source_checks:
            for dist_id in source_checks:
                delayed_checks[dist_id] = source_checks[dist_id]

        source_corrections = protocol_recipe.fusion_corrections[i_ts]
        if source_corrections:
            for fc_qubit in source_corrections:
                corrections[fc_qubit] = _export_fusion_correction(source_corrections[fc_qubit])

    for sub_sys, subsystem in protocol_recipe.subsystems.items():
        recipe_export.subsystems[sub_sys] = SubSystem(subsystem.nodes, n)

    return recipe_export


if __name__ == "__main__":
    # Convert every simulate_protocol.protocol_recipe.ProtocolRecipe object found in
    # "results/protocols/ProtocolRecipe/" into a ProtocolRecipeExport object built solely from classes in this
    # file, and dump it with "dill" to "results/protocols/ProtocolRecipeExport/". Such a dump can be loaded
    # anywhere without the underlying simulate_protocol files. Files that already exist in the export folder
    # are skipped.

    folder = "results/protocols/ProtocolRecipe/"
    folder_export = "results/protocols/ProtocolRecipeExport/"
    # Without this, a missing export folder makes every file fail with a misleading "Couldn't convert".
    os.makedirs(folder_export, exist_ok=True)
    files = os.listdir(folder)

    for prot_file in files:
        source_path = os.path.join(folder, prot_file)
        export_path = os.path.join(folder_export, prot_file)
        if os.path.isfile(export_path):
            print(f"File {prot_file} already existed.")
            continue
        try:
            # NOTE(review): dill.load can execute arbitrary code; only run this on trusted protocol files.
            with open(source_path, "rb") as source_handle:
                prot = dill.load(source_handle)
            # Serialise in memory first, so a failed conversion never leaves a partial file behind that later
            # runs would skip as "already existed".
            payload = dill.dumps(export_protocol_recipe(prot))
            with open(export_path, "wb") as export_handle:
                export_handle.write(payload)
            print(f"Created file {prot_file}.")
        except Exception as err:  # best-effort batch conversion: report and continue with the next file
            print(f"Couldn't convert file {prot_file}. ({type(err).__name__}: {err})")
