"""
2020-2024 Runsheng Ouyang (QuTech)
https://github.com/sebastiandebone/ghz_prot_I
"""
import sys
import matplotlib.pyplot as plt

import GHZ_prot_II.operations as op


# Global matplotlib styling applied to every figure created after this module
# is imported.
plt.rcParams.update({
    "font.family": "Calibri",
    "font.style": "normal",
    # "font.weight": "100",
    "font.stretch": "normal",
    "font.size": 18,
    "lines.linewidth": 1.2,
    "axes.linewidth": 0.8,
    "grid.linewidth": 0.4,
    "figure.autolayout": True,
})

# Layout constants shared by the drawing helpers below.
d_hor = 4      # horizontal distance between neighbouring tree columns
d_vec = 8      # vertical distance between consecutive tree levels
radius = 2     # radius of the circle drawn for each tree node
fontsize = 20  # font size of the text label placed on each node


def convert_to_letters(i):
    """Translate a node index, or a sequence of node indices, into letters.

    :param i: a single integer index, or a list/tuple of integer indices.
    :return: a string with one letter per index, in input order (an empty
        string for an empty sequence).
    :raises IndexError: if an index lies outside the 21-letter alphabet.
    """
    # The alphabet skips 'I' and 'J': index 7 is 'H' and index 8 is 'K'.
    letters = 'ABCDEFGHKLMNOPQRSTUVW'
    indices = i if isinstance(i, (list, tuple)) else [i]
    return ''.join(letters[index] for index in indices)


def get_left_width(root):
    """Return the number of nodes in the left subtree of ``root``."""
    left_subtree = root.left
    return get_width(left_subtree)


def get_right_width(root):
    """Return the number of nodes in the right subtree of ``root``."""
    right_subtree = root.right
    return get_width(right_subtree)


def get_width(root):
    """Return the number of nodes in the binary tree rooted at ``root``.

    Every node occupies its own column in the drawing, so the node count is
    the width of the tree in columns.

    :param root: a node exposing ``left`` and ``right`` attributes, or None.
    :return: 0 for an empty tree, otherwise the total node count.
    """
    # Identity check: an empty subtree is represented by the None singleton.
    if root is None:
        return 0
    return get_width(root.left) + 1 + get_width(root.right)


def get_height(root):
    """Return the number of levels in the binary tree rooted at ``root``.

    :param root: a node exposing ``left`` and ``right`` attributes, or None.
    :return: 0 for an empty tree, 1 for a single node, and so on.
    """
    # Identity check: an empty subtree is represented by the None singleton.
    if root is None:
        return 0
    return max(get_height(root.left), get_height(root.right)) + 1


def get_w_h(root):
    """Return ``(width, height)`` of the tree rooted at ``root``.

    The width is the node count from ``get_width`` and the height is the
    level count from ``get_height``.
    """
    return get_width(root), get_height(root)


def draw_a_node(x, y, s, ax):
    """Draw one tree node: a yellow circle with a text label.

    :param x: x coordinate of the circle centre.
    :param y: y coordinate of the circle centre.
    :param s: label text; it is centred horizontally on ``x`` with its bottom
        edge at ``y``.
    :param ax: matplotlib axes that receives both the circle and the label.
    """
    c_node = plt.Circle((x, y), radius=radius, color='yellow')
    ax.add_patch(c_node)
    # Put the label on the same axes as the circle instead of relying on
    # pyplot's "current axes", which need not be ``ax``.
    ax.text(x, y, s, ha='center', va='bottom', fontsize=fontsize)


def draw_a_edge(x1, y1, x2, y2):
    """Draw a black line from ``(x1, y1)`` to ``(x2, y2)``.

    The line goes on pyplot's current axes, whose axis decorations are
    switched off afterwards.
    """
    plt.plot((x1, x2), (y1, y2), 'k-')
    plt.axis('off')


def create_win(root):
    """Create the figure and axes on which the tree rooted at ``root`` is drawn.

    The axis limits span one ``d_hor`` step per tree column and one ``d_vec``
    step per tree level, plus one extra step in each direction.

    :param root: root node of the tree to be drawn.
    :return: ``(fig, ax, x, y)`` where ``(x, y)`` is the position at which the
        root node should be drawn.
    """
    n_columns, n_levels = get_w_h(root)
    canvas_width = d_hor * (n_columns + 1)
    canvas_height = d_vec * (n_levels + 1)

    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111)
    plt.xlim(0, canvas_width)
    plt.ylim(0, canvas_height)
    plt.axis('off')

    # The root sits to the right of all nodes of its left subtree, one
    # vertical step below the top of the canvas.
    root_x = d_hor * (get_left_width(root) + 1)
    root_y = canvas_height - d_vec
    return fig, ax, root_x, root_y


def _purification_operator_text(value):
    """Return the operator string shown on a purification node.

    The sequence returned by ``op.dec2signs(value.dec, value.n)`` is read as
    follows: every entry ``i >= 1`` that equals 1 toggles a flag on qubits
    ``i - 1`` and ``i`` (0-based).  If the first entry equals 1, every qubit
    is listed, as ``iY`` when flagged and ``X`` otherwise; if not, only the
    flagged qubits are listed, as ``Z``.  Qubits are numbered from 1 in the
    output and every entry is followed by a space.
    """
    signs = op.dec2signs(value.dec, value.n)
    flagged = [0] * value.n
    for i in range(1, value.n):
        if signs[i] == 1:
            flagged[i - 1] ^= 1
            flagged[i] ^= 1

    list_all_qubits = signs[0] == 1
    pieces = []
    for qubit, flag in enumerate(flagged, start=1):
        if list_all_qubits:
            pieces.append(('iY' if flag == 1 else 'X') + str(qubit) + ' ')
        elif flag == 1:
            pieces.append('Z' + str(qubit) + ' ')
    return ''.join(pieces)


def _build_node_label(root, print_reg_slot, print_id):
    """Return the multi-line text label for the tree node ``root``."""
    value = root.value

    # Header: "(n,k)" or, when requested, "(n,k,t)".
    label = '(' + str(value.n) + ',' + str(value.k)
    if print_reg_slot:
        label += ',' + str(value.t)
    label += ')'

    if value.n == 2 and value.k == 1:
        # Nodes with n == 2 and k == 1 are labelled as elementary links.
        label += '\nelem. link'
    elif value.p_or_f == 0:
        if value.r2 != 0:
            label += '\nr2 = ' + str(value.r2)
        label += '\npurification\n' + _purification_operator_text(value)
    else:
        if value.r1 != 0 or value.r2 != 0:
            label += '\n(r1,r2) = (' + str(value.r1) + ',' + str(value.r2) + ')'
        fusion_node = root.node_nrs[root.left.value.n - value.i - 1]
        label += '\nfusion in node ' + convert_to_letters(fusion_node)

    if print_id:
        label += '\nid = ' + str(root.id)
    if root.node_nrs is not None:
        label += '\n' + convert_to_letters(root.node_nrs)
    return label


def print_tree_by_inorder(root, x, y, ax, print_reg_slot, print_id):
    """Recursively draw the tree rooted at ``root`` with its root at ``(x, y)``.

    Despite the name, each node is drawn before its subtrees: first the node
    itself, then the edges to its children, then the left and the right
    subtree.

    :param root: node to draw, or None (in which case nothing is drawn).
    :param x: x coordinate of ``root``.
    :param y: y coordinate of ``root``.
    :param ax: matplotlib axes handed to ``draw_a_node``.
    :param print_reg_slot: if true, ``value.t`` is included in the label header.
    :param print_id: if true, ``root.id`` is included in the label.
    """
    if root is None:
        return

    draw_a_node(x, y, _build_node_label(root, print_reg_slot, print_id), ax)

    # Children sit one vertical step lower.  The horizontal offset leaves room
    # for the part of the child's own subtree that lies between parent and
    # child.
    lx = rx = 0
    ly = ry = y - d_vec
    if root.left is not None:
        lx = x - d_hor * (get_right_width(root.left) + 1)
        draw_a_edge(x, y, lx, ly)
    if root.right is not None:
        rx = x + d_hor * (get_left_width(root.right) + 1)
        draw_a_edge(x, y, rx, ry)

    # A missing child is passed on as None and ends the recursion immediately.
    print_tree_by_inorder(root.left, lx, ly, ax, print_reg_slot, print_id)
    print_tree_by_inorder(root.right, rx, ry, ax, print_reg_slot, print_id)


def find_F(root):
    """Return the ``state`` of the first (n=2, k=1) node on the left spine.

    Starting at ``root``, left children are followed until a node whose value
    has ``n == 2`` and ``k == 1`` is reached.  If that node's ``state`` is a
    list its first element is returned, otherwise ``state`` itself.  The
    result is used as the "F" value in output file names (presumably a
    fidelity -- confirm with the caller).

    :param root: root node of the tree; must have such a node on its chain of
        left children, otherwise the walk runs off the tree and fails with an
        AttributeError.
    """
    node = root
    while node.value.n != 2 or node.value.k != 1:
        node = node.left
    state = node.value.state
    if isinstance(state, list):
        return state[0]
    return state


def plot_protocol(root, ms_or_not, include_counter=0, counter=0, name=None, print_reg_slot=True, print_id=True):
    """Draw the protocol tree rooted at ``root`` and save it as a PDF file.

    :param root: root node of the protocol tree.
    :param ms_or_not: 1 if the tree is "ms", 0 if not.
    :param include_counter: not used by this function; kept so existing calls
        keep working.
    :param counter: number appended to the default file name when
        ``ms_or_not == 1``.
    :param name: output file name without the ".pdf" extension.  When None, a
        default name is built from ``root.value.n``, ``root.value.k`` and
        ``find_F(root)``.
    :param print_reg_slot: forwarded to ``print_tree_by_inorder``.
    :param print_id: forwarded to ``print_tree_by_inorder``.
    :raises SystemExit: if ``ms_or_not == 1``, no ``name`` is given and
        ``root.value.t1`` is None.
    """
    _, ax, x, y = create_win(root)
    print_tree_by_inorder(root, x, y, ax, print_reg_slot, print_id)

    # An explicit name is used as-is, whatever the value of ms_or_not.
    if name is not None:
        plt.savefig(name + ".pdf")
        return

    if ms_or_not == 1 and root.value.t1 is None:
        sys.exit('this tree is not ms, please set ms_or_not = 0')

    # Only the "ms" default name carries the counter; the other one ends in
    # "_" directly followed by the extension.
    suffix = str(counter) if ms_or_not == 1 else ""
    plt.savefig('protocol_(' + str(root.value.n) + ',' + str(root.value.k) + ')_found_for_F='
                + str(find_F(root)) + "_" + suffix + ".pdf")
    return
