"""
2022-2024 Sebastian de Bone (QuTech)
https://github.com/sebastiandebone/ghz_prot_II
_____________________________________________
"""
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image
import math
from collections import defaultdict

from analysis.plot_settings import plot_settings, plot_fonts


def plot_cut_off_time_plot(
        cut_off_dict_set_plot, convert_name,
        cut_off_information_plot, plot_type,
        # prot_cut_off_time_dependence,
        # plot_cut_off_time_dependence_subplots, plot_cut_off_subplot_one_figure,
        plot_cot_filtered_out_cots,
        plot_style, number_scaling_tries, font_size_big_start, font_size_small_start,
        plot_cot_dependence_protocol,
        figure_scale_factor,
        **kwargs
):
    """
    Plot the threshold as a function of the GHZ cycle (cut-off) time.

    The layout is selected by ``plot_type``:

    * ``"cot_dependence_subplots"``: a 2x2 grid with one subplot per data set
      (indices 0..3). The x-axis is the GHZ cycle time, and a secondary top axis
      shows the GHZ completion probability at each data point.
    * ``"cot_dependence_one_figure"``: all data sets in one axis. The x-axis is the
      GHZ completion probability (in %), and points below 86 % are dropped.
    * anything else: a single axis with the GHZ cycle time on the x-axis, labelled
      with ``plot_cot_dependence_protocol``.

    The figure is built ``number_scaling_tries`` times. Unless ``plot_style`` is
    ``"dissertation"``, each attempt is saved as
    ``generate_data_and_figures/figures/plot_<plot_type>_<i>.png`` with a tight
    bounding box. ``figure_scale_factor`` is then set to the ratio of the saved pixel
    width to the nominal figure width, and the next attempt's font sizes are scaled
    by it. The last figure is saved as
    ``generate_data_and_figures/figures/plot_<plot_type>.pdf`` and shown.

    Parameters
    ----------
    cut_off_dict_set_plot : sequence of dict
        Indexed by data-set number (indices 0, 1, 2 and 3 are accessed). Each dict
        maps a GHZ cycle time to a dict with keys ``'threshold'``, of the form
        ``(value, (lower, upper))``, and ``'GHZ_success_rates'``. The latter is a
        dict whose keys are compared to the threshold value and whose values are
        completion probabilities as fractions.
    convert_name : dict
        Maps protocol identifiers to display names.
    cut_off_information_plot : sequence
        Indexed like ``cut_off_dict_set_plot``. Entry ``[0]`` is the protocol
        identifier, ``[2]`` a parameter value and ``[3]`` a description; together
        they are used in titles and labels. Only read for the two
        ``cot_dependence_*`` layouts.
    plot_type : str
        Layout selector (see above). It is also part of the output file names.
    plot_cot_filtered_out_cots : container
        GHZ cycle times to leave out of the plot.
    plot_style : str
        Passed to ``plot_settings``. ``"paper"`` and ``"dissertation"`` change the
        figure size, legend spacing and saving behaviour.
    number_scaling_tries : int
        Number of times the figure is built.
    font_size_big_start, font_size_small_start : float
        Unscaled font sizes for titles and for labels/legends respectively.
    plot_cot_dependence_protocol : str
        Legend label for the single-axis (default) layout.
    figure_scale_factor : float
        Initial font scale factor.
    **kwargs
        Unused. Presumably accepted so that a shared settings dict can be unpacked
        into the call.

    Raises
    ------
    ValueError
        If ``cut_off_dict_set_plot`` is ``None`` (there is nothing to plot).
    """
    if cut_off_dict_set_plot is None:
        # Fail early with a clear message, rather than a TypeError when the data is indexed below.
        raise ValueError("cut_off_dict_set_plot is None: there is no cut-off time data to plot.")
    cut_off_dict_set = cut_off_dict_set_plot
    cut_off_information = cut_off_information_plot

    print(f"cut_off_information_plot = {cut_off_information}")
    print(f"cut_off_dict_set_plot = {cut_off_dict_set}")
    markers = ["o", "s", "v", "D", "p", "^", "h", "X", "<", "P", "*", ">", "H", "d", 4, 5, 6, 7, 8, 9, 10, 11]
    # Completion-probability labels (in %) hard-coded for specific GHZ cycle times.
    # They are only used when the measured rate rounds to 1.000 at three decimals.
    hard_coded_rates = {0.0442801: 99.999, 0.0904881: 99.991, 0.0977841: 99.997}

    def update_ticks_y(y, pos):
        # Thresholds are stored as probabilities; label the ticks as percentages.
        if y == 0:
            return 0
        return round(y * 100, 2)

    for i in range(number_scaling_tries):
        # Font sizes are scaled by the factor measured in the previous attempt (see end of loop).
        font_size_big = font_size_big_start * figure_scale_factor
        font_size_small = font_size_small_start * figure_scale_factor

        if plot_type == "cot_dependence_subplots":
            fig, ax = plt.subplots(2, 2, figsize=(13.33339, 11)) if plot_style == "paper" \
                else plt.subplots(2, 2, figsize=(9.48290, 8.5))
            colors = plot_settings(axes=[ax[0, 0], ax[0, 1], ax[1, 0], ax[1, 1]], style=plot_style, number_colors=4,
                                   set_grid=True)
            fig.suptitle("GHZ cycle time dependence on the surface code threshold", fontsize=font_size_big)
        elif plot_type == "cot_dependence_one_figure":
            fig = plt.figure(figsize=(6.4314, 6.7)) if plot_style == "paper" \
                else plt.figure(figsize=(0.8 * 9.48290, 7.7))
            ax = plt.gca()
            colors = plot_settings(axes=[ax], style=plot_style, number_colors=4, set_grid=True)
        else:
            fig = plt.figure(figsize=(6.4314, 5))
            ax = plt.gca()
            colors = plot_settings(axes=[ax], style=plot_style, number_colors=4, set_grid=True)
        colors = list(colors.values())

        for subplot_nmb in [0, 3, 1, 2]:
            cut_off_dict = cut_off_dict_set[subplot_nmb]
            x_values = [cot for cot in sorted(cut_off_dict.keys()) if cot not in plot_cot_filtered_out_cots]
            threshold_values = [cut_off_dict[cot]['threshold'] for cot in x_values]
            y_values = [v[0] for v in threshold_values]
            y_err_below = [v[0] - v[1][0] for v in threshold_values]
            y_err_above = [v[1][1] - v[0] for v in threshold_values]
            y_err = [y_err_below, y_err_above]

            if plot_type == "cot_dependence_subplots":
                sp_row, sp_col = divmod(subplot_nmb, 2)
                plot_object = ax[sp_row, sp_col]
                axis = plot_object
            else:
                # Single-axis layouts draw through the pyplot state machine.
                plot_object = plt
                axis = ax

            if plot_type in ["cot_dependence_subplots", "cot_dependence_one_figure"]:
                protocol = convert_name[cut_off_information[subplot_nmb][0]]
            else:
                protocol = plot_cot_dependence_protocol

            # GHZ completion probability (in %) at the threshold, one entry per cycle time.
            GHZ_rates_data = []
            for cot, threshold_value in zip(x_values, y_values):
                GHZ_rates = cut_off_dict[cot]['GHZ_success_rates']
                if GHZ_rates:
                    # Use the rate stored at the key closest to the threshold value.
                    closest_key = min(GHZ_rates, key=lambda rate_key: abs(rate_key - threshold_value))
                    GHZ_rate_threshold = GHZ_rates[closest_key]
                else:
                    GHZ_rate_threshold = 0
                if round(GHZ_rate_threshold, 3) >= 0.9999:
                    hard_coded = hard_coded_rates.get(cot)
                    GHZ_rates_data.append(
                        hard_coded if hard_coded is not None else round(GHZ_rate_threshold * 100, 2))
                else:
                    GHZ_rates_data.append(round(GHZ_rate_threshold * 100, 1))

            if plot_type == "cot_dependence_one_figure":
                # Plot against the completion probability instead of the cycle time, and
                # drop points below 86 %.
                kept = [j for j, rate in enumerate(GHZ_rates_data) if rate >= 86]
                x_values = [GHZ_rates_data[j] for j in kept]
                y_values = [y_values[j] for j in kept]
                y_err = [[y_err_below[j] for j in kept], [y_err_above[j] for j in kept]]

            # Default appearance (single-axis layout); overridden for the cot_dependence_* layouts.
            labelplot = f"{protocol}"
            colorplot = 'blue'
            markerplot = 'o'
            colorline = 'grey'
            alphaline = 1
            widthline = 1
            markersizeplot = 5
            markerfillstyle = 'full'
            markeralpha = 1
            if plot_type in ["cot_dependence_subplots", "cot_dependence_one_figure"]:
                par_value = cut_off_information[subplot_nmb][2]
                fig_information = cut_off_information[subplot_nmb][3]
                titleplot = protocol + ' at ' + r'$' + str(par_value) + r'$' + f' in {fig_information}'
                colorplot = colors[subplot_nmb]
                markerplot = markers[subplot_nmb]
                colorline = colorplot
                alphaline = 0.4
                widthline = 1.5
                markerfillstyle = 'none'
                markersizeplot = 8
                markeralpha = 0.8
                if plot_type == "cot_dependence_one_figure":
                    # In the combined figure, the per-set description becomes the legend label.
                    labelplot = titleplot
                    if plot_style == "paper":
                        titleplot = 'GHZ completion probability dependence on threshold'
                    else:
                        titleplot = 'GHZ completion probability dependence on the surface code threshold'
            else:
                titleplot = 'GHZ cycle time dependence on threshold'

            plot_object.plot(x_values, y_values, color=colorline, linestyle=':', alpha=alphaline,
                             linewidth=widthline)
            plot_object.errorbar(x_values, y_values, yerr=y_err, markersize=markersizeplot, color=colorplot,
                                 linestyle='None', marker=markerplot, fillstyle=markerfillstyle,
                                 label=labelplot, alpha=markeralpha)

            if plot_type == "cot_dependence_one_figure":
                xlabel = r'GHZ completion probability $p_\mathrm{GHZ}$ (%)'
            else:
                xlabel = r'GHZ cycle time $t_\mathrm{GHZ}$ (s)'
            ylabel = r'Error probability threshold $p_\mathrm{g}=p_\mathrm{m}$ (%)'

            if plot_type == "cot_dependence_one_figure":
                legendlabelspacing = 0.05 if plot_style == "paper" else 0.085
                # Legend is placed below the axes.
                plot_object.legend(fontsize=font_size_small, loc='upper center', ncol=1,
                                   labelspacing=legendlabelspacing, bbox_to_anchor=(0, -0.32, 1, 0.2))
                plot_object.title(titleplot, fontsize=font_size_big)
                axis.set_xticks(list(range(86, 101)))
            else:
                axis.set_title(titleplot, fontsize=font_size_big)
            axis.set_xlabel(xlabel, fontsize=font_size_small)
            # In the 2x2 grid only the left column gets a y-label.
            if not (plot_type == "cot_dependence_subplots" and sp_col == 1):
                axis.set_ylabel(ylabel, fontsize=font_size_small)

            # Y-ticks on a 1e-4 grid (2e-4 in the combined figure), snapped inward to the axis limits.
            min_y_tick, max_y_tick = axis.get_ylim()
            min_y_tick = math.ceil(min_y_tick * 1e4) / 1e4
            max_y_tick = math.floor(max_y_tick * 1e4) / 1e4
            tick_spacing = 2e-4 if plot_type == "cot_dependence_one_figure" else 1e-4
            axis.set_yticks(np.arange(min_y_tick, max_y_tick + 1e-4, tick_spacing))
            axis.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(update_ticks_y))
            if plot_type != "cot_dependence_one_figure":
                # Secondary top axis labelling each cycle time with its GHZ completion probability.
                ax2 = axis.twiny()
                for spine in ("top", "right", "bottom", "left"):
                    ax2.spines[spine].set_visible(False)
                ax2.set_xbound(axis.get_xbound())
                ax2.set_xlabel(r'GHZ completion probability $p_\mathrm{GHZ}$ (%)', fontsize=font_size_small)
                ax2.set_xticks(x_values)
                ax2.set_xticks([], minor=True)
                ax2.set_xticklabels(GHZ_rates_data, fontsize=font_size_small * 0.93, rotation=45)

        plt.tight_layout()
        if plot_style != "dissertation":
            # Measure how much the tight bounding box shrinks/grows the figure and use that
            # ratio to rescale the fonts of the next attempt.
            w = fig.get_size_inches()[0]
            png_path = "generate_data_and_figures/figures/plot_" + plot_type + f"_{i}.png"
            plt.savefig(png_path, bbox_inches='tight', pad_inches=0)
            # Image.open is lazy and keeps the file open; close it once the size is read.
            with Image.open(png_path) as im:
                saved_width = im.size[0]
            figure_scale_factor = saved_width / (w * fig.dpi)
            print(saved_width, w * fig.dpi, figure_scale_factor)

    if plot_style != "dissertation":
        plt.savefig("generate_data_and_figures/figures/plot_" + plot_type + ".pdf", bbox_inches='tight', pad_inches=0)
    else:
        plt.savefig("generate_data_and_figures/figures/plot_" + plot_type + ".pdf")
    plt.show()


def plot_main_function(
        protocols_to_be_plotted_per_set, protocols_calculated_per_set, convert_name,
        plot_type, number_scaling_tries, font_size_big_start, font_size_small_start, plot_style,
        figure_scale_factor,
        **kwargs
):
    """
    Plot protocol thresholds against a parameter that varies across simulation sets.

    ``plot_type`` selects the sets and the axes:

    * ``"link_efficiency"``: log x-axis labelled "Link efficiency". The top axis
      shows the coherence time scaling factor.
    * ``"bell_succ_prob"``: log x-axis labelled "Link efficiency", with values
      ``0.002 * 10 ** (4 + 0.25 * v)`` computed from a per-set exponent ``v``. The
      top axis shows the total photon detection probability.
    * ``"bell_quality"``: linear x-axis with the Bell pair fidelity
      ``0.5 * (1 + phi ** 2)``. The top axis shows the excitation error
      probability.

    Every protocol becomes one errorbar series, and its points are ordered like the
    sets. The figure is built ``number_scaling_tries`` times. When
    ``number_scaling_tries > 1``, each attempt is saved as
    ``generate_data_and_figures/figures/plot_<plot_type>_<i>.png`` and
    ``figure_scale_factor`` is updated to the ratio of the saved pixel width to the
    nominal figure width; this rescales the next attempt's fonts. The final figure
    is saved as ``generate_data_and_figures/figures/plot_<plot_type>.pdf`` and shown.

    Parameters
    ----------
    protocols_to_be_plotted_per_set : mapping or None
        Set name -> iterable of protocol entries to plot. If ``None``,
        ``protocols_calculated_per_set`` is used instead.
    protocols_calculated_per_set : mapping or iterable of pairs
        Converted with ``dict()``. Same structure as above. In each entry,
        ``entry[0]`` is the protocol name and ``entry[2]`` is
        ``(threshold, (lower, upper))``.
    convert_name : dict
        Protocol name -> legend label. Names that are absent are shown as-is.
    plot_type : str
        One of ``"link_efficiency"``, ``"bell_succ_prob"``, ``"bell_quality"``.
    number_scaling_tries : int
        Number of times the figure is built.
    font_size_big_start, font_size_small_start : float
        Unscaled font sizes for titles and for labels/legends respectively.
    plot_style : str
        ``"paper"`` and ``"dissertation"`` select fixed figure sizes. Other styles
        use matplotlib's default size. Only ``"paper"`` saves the PDF with a tight
        bounding box.
    figure_scale_factor : float
        Initial font scale factor.
    **kwargs
        Unused. Presumably accepted so that a shared settings dict can be unpacked
        into the call.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the supported values.
    """
    markers = ["o", "s", "v", "D", "p", "^", "h", "X", "<", "P", "*", ">", "H", "d", 4, 5, 6, 7, 8, 9, 10, 11]
    if plot_type == "link_efficiency":
        # Set name -> x-axis position.
        sets_in_plot = {'Set3r': 400, 'Set3d': 500, 'Set3q': 600, 'Set3c': 800, 'Set3e': 2000,
                        'Set3f': 20000, 'Set3g': 200000, 'Set3h': 10000000, 'Set3k': 100000000}
    elif plot_type == "bell_succ_prob":
        f_eta_values = {'Set6k': 5.5, 'Set6g': 6, 'Set6m': 6.5, 'Set6h': 7, 'Set3e': 8}
        sets_in_plot = {k: 0.002 * 10 ** (4 + 0.25 * v) for k, v in f_eta_values.items()}
        top_ticks_labels = [round(math.sqrt(2) * 10 ** (0.125 * v - 1.5), 4) for v in f_eta_values.values()]
    elif plot_type == "bell_quality":
        f_phi_values = {'Set5r': -3, 'Set5q': -2, 'Set5p': -1, 'Set5a': 0, 'Set5b': 1, 'Set5c': 2, 'Set5d': 3,
                        'Set5e': 4}
        phi_values = {k: 0.84 + 0.03 * v for k, v in f_phi_values.items()}
        sets_in_plot = {k: 0.5 * (1 + v ** 2) for k, v in phi_values.items()}
        top_ticks_labels = [round(1 - (v / (math.sqrt(0.95) * (2 * 0.999 - 1) ** 2)) ** (1 / 2), 3) for v in
                            phi_values.values()]
    else:
        raise ValueError(f"Unknown plot_type {plot_type!r}; expected 'link_efficiency', 'bell_succ_prob' "
                         f"or 'bell_quality'.")

    if protocols_to_be_plotted_per_set is None:
        protocols_calculated_per_set = dict(protocols_calculated_per_set)
    else:
        protocols_calculated_per_set = protocols_to_be_plotted_per_set
    print(f"\n\n\n\nProtocols to be plotted per set: {protocols_calculated_per_set}.")

    def update_ticks_y(y, pos):
        # Thresholds are probabilities; label them as percentages. The value -0.0004 was
        # presumably a placeholder for protocols without a threshold ("NT"). Negative ticks
        # are currently filtered out before this formatter is applied.
        if y == -0.0004:
            return "NT"
        elif y == 0:
            return 0
        else:
            return round(y * 100, 2)

    def update_ticks_x(x, pos):
        if x == 0:
            return 0
        elif x < 1 and plot_type in ("bell_quality", "link_efficiency"):
            return round(x, 2)
        elif x == 10 ** 7:
            # Positions beyond the "//" axis breaks get symbolic labels instead of powers of ten.
            return '(*)'
        elif x == 10 ** 8:
            return r'($\ddag$)'
        else:
            return "$10^{" + f"{int(math.log10(x))}" + "}$"

    def minor_tick_label(t):
        # Label 400, 600, 800 and 2000 as "a cdot 10^b" in mathtext; other minor ticks stay blank.
        if t not in [400, 600, 800, 2000]:
            return ""
        mantissa, exponent = (int(t / 100), 2) if t < 1000 else (int(t / 1000), 3)
        return r"$" + f"{mantissa}" + r"\cdot10^" + f"{exponent}" + "$"

    log_x_axis = plot_type in ("link_efficiency", "bell_succ_prob")

    for i in range(number_scaling_tries):
        # Font sizes are scaled by the factor measured in the previous attempt (see end of loop).
        font_size_big = font_size_big_start * figure_scale_factor
        font_size_small = font_size_small_start * figure_scale_factor

        if plot_style == "paper":
            fig = plt.figure(figsize=(6.4314, 6))
        elif plot_style == "dissertation":
            fig = plt.figure(figsize=(0.8 * 9.48290, 7))
        else:
            # Fresh default-size figure, so that attempts do not draw over each other and
            # `fig` is defined for the scaling step below.
            fig = plt.figure()
        colors = plot_settings(style=plot_style, number_colors=len(markers), set_grid=True, set_minor_grid=True)
        colors = list(colors.values())

        # Protocol name -> {x-axis value: (threshold, (lower, upper))}, ordered like sets_in_plot.
        prots = defaultdict(dict)
        for set_name, x_position in sets_in_plot.items():
            if set_name not in protocols_calculated_per_set:
                continue
            for prot in protocols_calculated_per_set[set_name]:
                prots[prot[0]][x_position] = prot[2]

        for i_col_mar, (prot, threshold_per_x) in enumerate(prots.items()):
            color = colors[i_col_mar]
            marker = markers[i_col_mar]
            label = convert_name.get(prot, prot)
            x_values = list(threshold_per_x.keys())
            y_values = [threshold_per_x[x][0] for x in x_values]
            y_err_below = [threshold_per_x[x][0] - threshold_per_x[x][1][0] for x in x_values]
            y_err_above = [threshold_per_x[x][1][1] - threshold_per_x[x][0] for x in x_values]
            y_err = [y_err_below, y_err_above]
            if prot == "Best protocol GHZ optimization":
                # Hard-coded integer annotations placed to the left of each point; zip stops
                # at the shorter list instead of raising if the data has more points.
                if plot_type == "link_efficiency":
                    # Drop the last two points (the new protocol data without decoherence).
                    x_values = x_values[:-2]
                    y_values = y_values[:-2]
                    y_err = [y_err_below[:-2], y_err_above[:-2]]
                    k_values = [6, 6, 7, 7, 10, 10, 10]
                    for x_value, y_value, k_value in zip(x_values, y_values, k_values):
                        plt.text(x_value - x_value / 7, y_value, str(k_value), size=font_size_small, ha='right')
                if plot_type == "bell_succ_prob":
                    k_values = [6, 7, 7, 7, 10]
                    for x_value, y_value, k_value in zip(x_values, y_values, k_values):
                        plt.text(x_value - x_value / 30, y_value, str(k_value), size=font_size_small, ha='right')
                if plot_type == "bell_quality":
                    k_values = [12, 12, 12, 11, 11, 11, 10, 7]
                    for x_value, y_value, k_value in zip(x_values, y_values, k_values):
                        plt.text(x_value - 0.003, y_value, str(k_value), size=font_size_small, ha='right')
            plt.plot(x_values, y_values, color=color, linestyle=':', alpha=0.4, linewidth=1.5)
            plt.errorbar(x_values, y_values, yerr=y_err, markersize=8, fillstyle="none", color=color,
                         linestyle='None', marker=marker, label=f'{label}', alpha=0.8)

        ax = plt.gca()
        if log_x_axis:
            ax.set_xscale('log')
        if plot_type == "link_efficiency":
            xlabel = r'Link efficiency $\eta^{\ast}_\mathrm{link}$'
            title = 'Thresholds at different link efficiency values'
        elif plot_type == "bell_succ_prob":
            xlabel = r'Link efficiency $\eta^{\ast}_\mathrm{link}$'
            title = 'Thresholds at different entanglement success rates'
        else:  # "bell_quality" (validated above)
            xlabel = r'Bell pair fidelity $F_\mathrm{link}$'
            title = 'Thresholds at different Bell pair fidelities'

        ylabel = r'Error probability threshold $p_\mathrm{g}=p_\mathrm{m}$ (%)'
        plt.legend(prop={'size': font_size_small})
        ax.set_title(title, fontsize=font_size_big)
        ax.set_xlabel(xlabel, fontsize=font_size_small)
        ax.set_ylabel(ylabel, fontsize=font_size_small)

        x_limits = ax.get_xlim()
        ypos = ax.get_ylim()  # the "//" break markers are drawn at the lower y-limit
        if plot_type == "link_efficiency":
            break_loc_x = [math.sqrt(10) * 10 ** 6, math.sqrt(10) * 10 ** 7] if 'Set3h' in sets_in_plot else [
                math.sqrt(10) * 10 ** 6]
            # Keep only minor ticks inside the view and left of the first break.
            minor_ticks = [x for x in ax.xaxis.get_ticklocs(minor=True)
                           if x_limits[0] <= x <= x_limits[1] and x <= 10 ** 6]
            for bl in break_loc_x:
                plt.text(bl, ypos[0], r'//', fontsize=10, zorder=101, horizontalalignment='center',
                         verticalalignment='center')

        ax.set_yticks([y for y in ax.yaxis.get_ticklocs() if y >= 0])
        ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(update_ticks_x))
        ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(update_ticks_y))

        if plot_type == "bell_succ_prob":
            minor_ticks = [x for x in ax.xaxis.get_ticklocs(minor=True) if x_limits[0] <= x <= x_limits[1]]
            minor_ticks = [400] + minor_ticks
            minor_tick_labels = [minor_tick_label(t) for t in minor_ticks]
            ax.set_xticks(minor_ticks, minor=True)

        # Secondary top axis showing the underlying hardware parameter of each set.
        ax2 = ax.twiny()
        if log_x_axis:
            ax2.set_xscale('log')

        plot_settings(axes=[ax2], style=plot_style, set_grid=False)
        ax2.set_xbound(ax.get_xbound())
        if plot_type == "link_efficiency":
            ax.set_xticks(minor_ticks, minor=True)
            ax2.set_xlabel(r'Coherence time scaling factor $f_\mathrm{dec}$', fontsize=font_size_small)
            ax2.set_xticks([400, 500, 600, 800, 2000, 20000, 200000])
            ax2.set_xticks([], minor=True)
            ax2.set_xticklabels([2, "", 3, 4, 10, 100, 1000], fontsize=font_size_small)
        elif plot_type == "bell_succ_prob":
            ax.set_xticklabels(minor_tick_labels, fontsize=font_size_big, minor=True)
            ax2.set_xlabel(r'Total photon detection probability $\eta_\mathrm{ph}$', fontsize=font_size_small)
            ax2.set_xticks(list(sets_in_plot.values()))
            ax2.set_xticks([], minor=True)
            ax2.set_xticklabels(top_ticks_labels, fontsize=font_size_small)
        else:
            ax2.set_xlabel(r'Excitation error probability $p_\mathrm{EE}$', fontsize=font_size_small)
            ax2.set_xticks(list(sets_in_plot.values()))
            ax2.set_xticks([], minor=True)
            ax2.set_xticklabels([f'{xtt:.3f}' for xtt in top_ticks_labels], fontsize=font_size_small)

        if number_scaling_tries > 1:
            # Measure how much the tight bounding box changes the figure width and use that
            # ratio to rescale the fonts of the next attempt.
            w = fig.get_size_inches()[0]
            png_path = f"generate_data_and_figures/figures/plot_{plot_type}_{i}.png"
            plt.savefig(png_path, bbox_inches='tight', pad_inches=0)
            # Image.open is lazy and keeps the file open; close it once the size is read.
            with Image.open(png_path) as im:
                saved_width = im.size[0]
            figure_scale_factor = saved_width / (w * fig.dpi)
            print(saved_width, w * fig.dpi, figure_scale_factor)

    if plot_style == "paper":
        plt.savefig("generate_data_and_figures/figures/plot_" + plot_type + ".pdf", bbox_inches='tight', pad_inches=0)
    else:
        plt.savefig("generate_data_and_figures/figures/plot_" + plot_type + ".pdf")
    plt.show()



