"""
2021-2024 Sébastian de Bone (QuTech)
https://github.com/sebastiandebone/ghz_prot_II/
"""
import GHZ_prot_II.operations as ops


def nodes_overlap(nodes_1, nodes_2):
    """
    Function that finds the node numbers that two lists with network node numbers have in common.

    Parameters
    ----------
    nodes_1 : list of integers or None
        List with node numbers for a certain link
    nodes_2 : list of integers or None
        List with node numbers of a certain link

    Returns
    -------
    overlap : list of integers
        Node numbers present in both lists, without duplicates and in the order in which they first appear in
        ``nodes_1``. An empty list is returned if either input is None or if the lists have no node in common.
    """
    if nodes_1 is None or nodes_2 is None:
        return []
    nodes_2_set = set(nodes_2)
    # dict.fromkeys removes duplicates while keeping insertion order, so the result is deterministic (the iteration
    # order of a plain set intersection is arbitrary).
    return [node for node in dict.fromkeys(nodes_1) if node in nodes_2_set]


def find_fused_node(nodes_nrs_left, nodes_nrs_right):
    """
    Function that finds the single overlapping node for two lists with network node numbers, i.e., the node at which
    a fusion operation between the two links takes place.

    Parameters
    ----------
    nodes_nrs_left : list of integers
        List with node numbers for a certain link
    nodes_nrs_right : list of integers
        List with node numbers of a certain link

    Returns
    -------
    fused_node : integer
        The node number that appears in both lists.

    Raises
    ------
    AssertionError
        If the two lists do not share exactly one node number (i.e., they share none, or more than one), or if one
        of them is None.
    """
    overlap = nodes_overlap(nodes_nrs_left, nodes_nrs_right)
    # Raise explicitly instead of using an assert statement so the check is not stripped when Python runs with -O;
    # the exception type is kept as AssertionError for backward compatibility with existing callers.
    if len(overlap) != 1:
        raise AssertionError("Expected exactly 1 overlapping node number for a fusion operation, but found {}: {}."
                             .format(len(overlap), overlap))
    return overlap[0]


def find_distillation_operator(dec, inv_nodes):
    """
    Function that finds the distillation operator that needs to be applied based on a decimal value. In the return
    list "dist_op", a value "3" corresponds to a controlled-"iY" gate being applied, a value "2" to a controlled-"X"
    gate and a "1" corresponds to a controlled-"Z" gate being applied in the node in question.

    Parameters
    ----------
    dec : positive integer
        Decimal value describing the distillation operator that is measured in this distillation step. It is
        converted to n = len(inv_nodes) digits with ops.dec2signs. The first digit selects an X/iY-type operator (1)
        or a Z-type operator (0). Every later digit i that equals 1 toggles the operator on positions i - 1 and i.
    inv_nodes : list of integers
        Node numbers of qubits of the final states after the distillation measurement (not all these qubits have to be
        part of the distillation operation themselves - only the ones getting a "1" in the array "inv_nodes_op").

    Returns
    -------
    dist_operator : string
        String describing the distillation operator in text format, including the qubit numbers on which the operator
        works (e.g., ". D_op: X0 X2.").
    dist_op : list of integers
        List describing the distillation operator in numeric format, excluding the qubit numbers on which the operator
        works (e.g., [2, 2]). For an X/iY-type operator every node in inv_nodes gets an entry. For a Z-type operator
        only the nodes on which a Z acts get an entry.
    """
    n = len(inv_nodes)
    # NOTE(review): ops.dec2signs is assumed to return a list of n binary digits (0/1) - confirm in operations module.
    dec2bin = ops.dec2signs(dec, n)
    z_or_xy = dec2bin[0]

    # Each set digit i (i >= 1) toggles positions i - 1 and i; inv_nodes_op only ever holds 0 or 1.
    inv_nodes_op = [0] * n
    for i in range(1, n):
        if dec2bin[i] == 1:
            inv_nodes_op[i - 1] ^= 1
            inv_nodes_op[i] ^= 1

    terms = []
    dist_op = []
    for node, toggled in zip(inv_nodes, inv_nodes_op):
        if z_or_xy == 1:
            if toggled == 1:
                terms.append("iY" + str(node))
                dist_op.append(3)
            else:
                terms.append("X" + str(node))
                dist_op.append(2)
        elif toggled == 1:
            terms.append("Z" + str(node))
            dist_op.append(1)

    # Produces ". D_op: X0 X2." or, when no term is present, ". D_op:." (same as the original trimming logic).
    dist_operator = " ".join([". D_op:"] + terms) + "."

    return dist_operator, dist_op


def remove_link_ids_from_qubit_memory(qubit_memory, link_id_list):
    """
    Function that clears every memory entry that holds one of the given link IDs.

    Parameters
    ----------
    qubit_memory : list of lists
        For each node, a list with the entry stored in each of its memory qubits (presumably a link ID, or None for an
        empty slot - confirm against callers). This object is modified in place.
    link_id_list : list
        Link IDs that have to be removed from the memory.

    Returns
    -------
    qubit_memory : list of lists
        The same object that was passed in, with every entry that occurs in ``link_id_list`` replaced by None.
    """
    for node in qubit_memory:
        for i_qubit, qubit in enumerate(node):
            if qubit in link_id_list:
                node[i_qubit] = None
    return qubit_memory
