"""
2022-2024 Sebastian de Bone (QuTech)
https://github.com/sebastiandebone/GHZ_prot_II
_____________________________________________
"""
import os
import time
import argparse
from datetime import datetime
import pickle
import sys
sys.path.insert(1, '.')

import GHZ_prot_II.da_search as das
import GHZ_prot_II.auxiliaries_and_help_files.calculate_attempts_per_echo as cape
import GHZ_prot_II.simulate_protocol.run_auto_generated_protocol as ragp
from utilities.files import get_full_path

ACCURACY = 15


class LoadFromFile(argparse.Action):
    """argparse action that reads additional command-line arguments from a text file.

    The option using this action is registered with ``type=open``, so ``values`` is an
    already-opened file object. Its contents are split on whitespace and parsed into the
    same namespace as the regular command line. Everything from a ``#`` up to the end of
    its line is treated as a comment and ignored.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        arguments = []
        with values as argument_file:
            for line in argument_file:
                # Drop the comment part of every line. Filtering only the tokens that
                # contain '#' is not enough: the other words of a multi-word comment (or of
                # a commented-out option) would be passed on as stray positional arguments.
                code_part, _, _ = line.partition("#")
                arguments.extend(code_part.split())
        parser.parse_args(arguments, namespace)


def parse_arguments(argv=None):
    """Build the command-line parser for the protocol search and return the parsed settings.

    After parsing, a few values are post-processed:
    a single ``-seeds`` value is unwrapped from its list, ``-number_of_threads`` is converted
    to ``int`` unless it is ``'auto'``, ``-n_DD`` is converted to ``float`` unless it is
    ``'None'``, ``'none'`` or ``'auto'``, and the string flags ``-calculate_prot``,
    ``-init_seed_search`` and ``-reshuffle_protocols`` are converted to booleans.

    :param argv: list of argument strings to parse; ``None`` (the default) parses
        ``sys.argv[1:]``, exactly like the original zero-argument call.
    :return: dict mapping argument names (without the leading '-') to their values.
    """
    parser = argparse.ArgumentParser(description='Calculate protocol recipes.')

    parser.add_argument('-n_max', required=False, type=int, default=4,
                        help='An integer describing the maximum number of parties involved in the search.')
    parser.add_argument('-k_max', required=False, type=int, default=4,
                        help='An integer describing the maximum number of Bell pairs involved in the search.')
    parser.add_argument('-nstate', required=False, type=int, default=10,
                        help='An integer describing the buffer size involved in the search.')
    parser.add_argument('-number_id_elems', required=False, type=int, default=1,
                        help='An integer describing how many (random) versions of the order in which Bell pairs are '
                             'created are taken into account.')

    parser.add_argument('-input_state', required=False, type=str, default="0.9",
                        help='A string describing the input Bell pair fidelity involved in the search. If this is set '
                             'to a float between 0 and 1, the Bell pairs are set to Werner states with this fidelity.')

    parser.add_argument('-da_type', required=False, type=str, default='mpuF',
                        help='A string describing the algorithm type involved in the search.')
    parser.add_argument('-rot_type', required=False, type=str, default='none',
                        help='A string describing the type of rotations of ancillary state involved in the search.')
    parser.add_argument('-calculate_prot', required=False, type=str, default='True',
                        help='A boolean describing whether or not noisy versions of the protocols should be evaluated '
                             'during the search.')

    parser.add_argument('-input_data_file', required=False, type=str, default=None,
                        help='Filename of data structure that is used as the basis for this search.')
    parser.add_argument('-n_start', required=False, type=int, default=2,
                        help='An integer describing for which n value the search should start.')
    parser.add_argument('-k_start', required=False, type=int, default=1,
                        help='An integer describing for which k value the search should start.')

    parser.add_argument('-number_of_threads', required=False, default='auto',
                        help='A parameter describing how many threads can be used simultaneously in case '
                             '-calculate_prot is set to True.')
    parser.add_argument('-filename_prefix', required=False, type=str, default='',
                        help='A parameter describing how the pickle file with results should start.')
    parser.add_argument('-failed_protocol_prefix', required=False, type=str, default='Protocol_failed',
                        help='The folder in which we place failed protocols.')
    parser.add_argument('-max_calc_time', required=False, type=str, default=None,
                        help='A string in the format DD-HH:MM:SS, HH:MM:SS, MM:SS or SS describing the maximum '
                             'calculation time.')
    parser.add_argument('-seeds', required=False, type=int, nargs="*",
                        help='One or more integers describing how many seeds, or which seed numbers, should be used '
                             'to average over in case -calculate_prot is True.')
    parser.add_argument('-init_seed_search', required=False, type=str, default='True',
                        help='A boolean describing whether or not part of the seeds should first be used to get a first '
                             'idea of a protocol (and see if there is a chance that it is better than what we have in '
                             'the buffer so far), before the remaining seeds are used.')
    parser.add_argument('-skip_similar_protocols', required=False, type=str, default='identical',
                        help='A string describing if a protocol should be considered (i.e., considered in the search) '
                             'if it is either \'identical\' or \'similar\' (in its binary tree structure) to a '
                             'protocol that already has been considered in the search. If one wants to not skip '
                             'protocols, it is sufficient to pass \'None\', \'no\' or \'n\'.')
    parser.add_argument('-reshuffle_protocols', required=False, type=str, default='True',
                        help='A boolean describing whether the protocol should be \'sliced\' before given to the '
                             'available CPU cores (in order to have a more \'balanced\' distribution of them over the '
                             'available CPUs).')
    parser.add_argument('-p_g', required=False, type=float, default=None,
                        help='A float describing what the (two-qubit) gate and measurement error probability for the '
                             'protocols in the search should be.')
    parser.add_argument('-alpha', required=False, type=float, default=None,
                        help='A float describing what the bright-state population for the single-click protocol '
                             'should be.')
    parser.add_argument('-n_DD', required=False, type=str, default='None',
                        help='A number describing how many entanglement generation attempts fit in half of a '
                             'dynamical decoupling sequence, or a string "auto" indicating that the program should '
                             'automatically find the best value for n_DD based on the other Bell state parameters '
                             'given.')
    parser.add_argument('-compare_metric', required=False, type=str, default='None',
                        help='String describing what metric we want to use to determine if one protocol is better than '
                             'another protocol: None will use the default setting, "ghz_fid" will use the fidelity of '
                             'the GHZ state, "stab_fid" will use the stabilizer fidelity, and "weighted_sum" will use '
                             'the weighted sum of the superoperator. Using a metric "weighted_sum_max" will only use '
                             '"weighted_sum" on (n_max,k_max), and "stab_fid" for protocols before that; this setting '
                             'will also overwrite a pre-calculated (n_max,k_max) buffer that was included with the '
                             'input_data_file. Further, the metric "logical_success" will compare protocols by looking '
                             'at the logical success rate over 5000 of their superoperator for a distance 4 toric '
                             'code. This will, however, only work for n=4: for all other values of n, this setting '
                             'will default back to using "stab_fid" as compare metric. If the "logical_success" '
                             'setting is active, "-init_seed_search" is set to False.')

    parser.add_argument('-argument_file',
                        help="Loads values from a file instead of the command line.",
                        type=open,
                        action=LoadFromFile)

    args_vars = vars(parser.parse_args(argv))

    # A single seed value is passed on as a plain int rather than a one-element list.
    if args_vars['seeds'] and len(args_vars['seeds']) == 1:
        args_vars['seeds'] = args_vars['seeds'][0]
    if args_vars['number_of_threads'] != 'auto':
        args_vars['number_of_threads'] = int(args_vars['number_of_threads'])
    # 'None'/'none'/'auto' are kept as strings for the caller to interpret.
    if args_vars['n_DD'] not in ['None', 'none', 'auto']:
        args_vars['n_DD'] = float(args_vars['n_DD'])

    # Boolean flags are declared as strings so that e.g. "-calculate_prot False" works.
    truthy_values = ('true', 'yes', '1', 'y', 't')
    for arg in ['calculate_prot', 'init_seed_search', 'reshuffle_protocols']:
        args_vars[arg] = args_vars[arg].lower() in truthy_values

    return args_vars


def max_calc_time_to_seconds(max_calc_time):
    """Convert a maximum-calculation-time string to a number of seconds.

    Accepted formats are ``DD-HH:MM:SS``, ``HH:MM:SS``, ``MM:SS`` and ``SS``. The fields are
    summed directly instead of going through ``datetime.strptime``, so they are not
    range-limited (e.g. ``"36:00:00"``, ``"90:00"`` or ``"3600"`` are all valid).

    :param max_calc_time: the time string passed via ``-max_calc_time``.
    :return: the total duration in seconds, as a float.
    :raises ValueError: if the string has none of the formats above, or contains a
        non-integer or negative field.
    """
    elements = max_calc_time.split(":")
    if not 1 <= len(elements) <= 3:
        raise ValueError("max_calc_time is not understood.")
    days = 0
    try:
        # Only the full HH:MM:SS form may carry a "DD-" day prefix.
        if len(elements) == 3 and "-" in elements[0]:
            day_part, _, hour_part = elements[0].partition("-")
            days = int(day_part)
            elements[0] = hour_part
        # Left-pad with zeros so "SS" and "MM:SS" line up as [hours, minutes, seconds].
        padded = ["0"] * (3 - len(elements)) + elements
        hours, minutes, seconds = (int(element) for element in padded)
    except ValueError as err:
        raise ValueError("max_calc_time is not understood.") from err
    if min(days, hours, minutes, seconds) < 0:
        raise ValueError("max_calc_time is not understood.")
    return float(((days * 24 + hours) * 60 + minutes) * 60 + seconds)


if __name__ == "__main__":
    args_vars = parse_arguments()

    t = time.time()
    n_max = args_vars['n_max']
    k_max = args_vars['k_max']
    nstate = args_vars['nstate']
    nmb_id_elems = args_vars['number_id_elems']
    input_state = args_vars['input_state']
    da_type = args_vars['da_type']
    rot_type = args_vars['rot_type']
    calculate_prot = args_vars['calculate_prot']
    number_of_threads = args_vars['number_of_threads']
    seeds = args_vars['seeds']
    init_seed_search = args_vars['init_seed_search']
    skip_similar_prots = args_vars['skip_similar_protocols']
    reshuffle_protocols = args_vars['reshuffle_protocols']
    input_data_file = args_vars['input_data_file']
    p_g = args_vars['p_g']
    alpha = args_vars['alpha']
    n_DD = args_vars['n_DD']
    # The string 'None' on the command line selects the default compare metric.
    metric = args_vars['compare_metric'] if args_vars['compare_metric'] != 'None' else None

    max_calc_time = args_vars['max_calc_time']
    max_calc_time_in_seconds = None
    if max_calc_time is not None:
        max_calc_time_in_seconds = max_calc_time_to_seconds(max_calc_time)
        print(f"Time limit is set at: {max_calc_time}.")
        print(f"This is converted to {max_calc_time_in_seconds} seconds.")
        # NOTE(review): the converted limit is only reported; it is not passed on to
        # das.dynamic_algorithm below, so this script does not enforce it -- confirm intent.

    # Derive n_DD from the parameters returned by simulate_protocol_recipe for the named input state.
    if n_DD == 'auto' and isinstance(input_state, str):
        bell_pars = ragp.simulate_protocol_recipe(None, set_name=input_state, return_bell_par=True, alpha=alpha)
        if 'p_link' in bell_pars and bell_pars['p_link'] is not None:
            p_link = bell_pars['p_link']
            F_link = bell_pars['F_link']
        else:
            p_link, F_link = cape.calculate_bell_state_properties(bell_pars)
        n_DD = cape.number_of_entanglement_attempts_per_echo_pulse(p_link=p_link,
                                                                   t_link=bell_pars['t_link'],
                                                                   t_pulse=bell_pars['t_pulse'])
        print(f"The parameter 'n_DD' is automatically set at {n_DD}, for p_link = {p_link}, "
              f"t_link = {bell_pars['t_link']} and t_pulse = {bell_pars['t_pulse']}, with F_link = {F_link}")

    failed_protocol_prefix = args_vars['failed_protocol_prefix']
    if input_data_file is not None:
        n_start, k_start = args_vars['n_start'], args_vars['k_start']
    else:
        # Without an input data file the -n_start/-k_start arguments are ignored.
        n_start, k_start = 2, 1

    # Strip one trailing "/" from the prefix. Indexing [-1] would crash on an empty prefix,
    # so endswith() is used instead.
    if failed_protocol_prefix is None:
        failed_protocol_prefix = ""
    elif failed_protocol_prefix.endswith("/"):
        failed_protocol_prefix = failed_protocol_prefix[:-1]
    main_folder = f"results/protocols/{failed_protocol_prefix}/"
    # makedirs also creates missing parent folders (os.mkdir failed if "results/protocols"
    # did not exist yet); exist_ok makes the call a no-op for an existing folder.
    os.makedirs(main_folder, exist_ok=True)

    print(f"Calculating dynamic algorithm with the following parameters: \n{args_vars}.")
    data, number_of_prots = das.dynamic_algorithm(n_max, k_max, input_state, input_data_file, n_start, k_start,
                                                  da_type=da_type, nstate=nstate, rot_type=rot_type,
                                                  show_or_not=1,
                                                  calculate_prot=calculate_prot, nmb_id_elems=nmb_id_elems,
                                                  number_of_threads=number_of_threads, seeds=seeds,
                                                  failed_protocol_prefix=failed_protocol_prefix,
                                                  init_seed_search=init_seed_search,
                                                  skip_similar_prots=skip_similar_prots,
                                                  reshuffle_protocols=reshuffle_protocols,
                                                  p_g=p_g, alpha=alpha, n_DD=n_DD, metric=metric)
    print("")
    # Report the best buffered entry for n_max parties for every k from n_max - 1 up to k_max.
    for k in range(n_max - 1, k_max + 1):
        print(f"Maximal metric for (n, k) = ({n_max}, {k}) is {data[n_max][k][0].state}.")
    print(f"Elapsed time is {time.time() - t}.")
    print(f"Number of protocols considered = {number_of_prots}.")
    file_name_addition = "_" + str(nmb_id_elems) if calculate_prot else ""
    file_name_total = f"{args_vars['filename_prefix']}data_{n_max}_{k_max}_{nstate}" + file_name_addition + \
                      f"_{input_state}_{seeds}_{da_type}_algv1.pkl"
    # NOTE(review): the return value of get_full_path is discarded; presumably it is called
    # for a side effect such as creating the target directory -- confirm.
    get_full_path(file_name_total, strip_filename=True)
    # Use a context manager so the pickle file is flushed and closed deterministically.
    with open(file_name_total, "wb") as output_file:
        pickle.dump(data, output_file)