%{
Configures and runs [Experiment]s.

The settings below correspond exactly to the code that was executed to create
the results in the paper. By default, instead of running the experiments again,
results are loaded from the `cache/` directory. To re-run all experiments,
change the value of `load_enabled` to `false` below. Since this code is fully
deterministic, it gives the exact same results as in the paper. Make sure you
install the required add-ons listed below, and then run the script. The script
may take several hours to complete, depending on available hardware. On a
high-performance cluster with 32 cores, the script took 3 hours to complete.


%% Configuration
This script runs several [Experiment]s from a [Laboratory]. Each experiment is
described by an [ExperimentConfig], which is constructed from a shared base
configuration, plus the specific experiment's partial configuration.


%% Requirements
%%% Hard requirements
* MATLAB R2021b.
  Newer versions also work, but may not give the exact same results, because some
  functions invoke the random number generator differently depending on the MATLAB
  version.

%%% Recommended
* MATLAB Parallel Computing Toolbox.
  Significantly speeds up computation. You must set `parallel = false` if you do
  not have this toolbox.
* MATLAB Instrument Control Toolbox.
  Required by the progress bar add-on listed below.
* "PARFOR progress monitor (progress bar) v4", v2.0.4, by Frerk Saxen.
  Displays a progress bar GUI during computation.
  Requires both "MATLAB Parallel Computing Toolbox" and "MATLAB Instrument Control
  Toolbox".

%%% Requirements for ML variant
The following requirements apply only if this script's settings are changed to
use machine learning models.

* NVIDIA GPU with CUDA.
  Significantly speeds up training.
  AMD GPUs are not supported by MATLAB.
* MATLAB Deep Learning Toolbox.
* MATLAB Statistics and Machine Learning Toolbox.
* Internet connection.
  Necessary to download datasets during the first run.
%}


    % TODO[R2022b]: Uncomment all `arguments (Output)` blocks


clear;

% Split the work across the tasks of a SLURM job array, if there is one.
%
% SLURM_ARRAY_TASK_MAX is the largest task *ID*, not the number of tasks, so it
% is not a reliable part count. `part_idx` appears to be 0-based (it defaults
% to 0 when there is a single part):
%   * with `--array=0-3`, TASK_MAX = 3 gives one part too few;
%   * with `--array=1-4`, the IDs run 1..4, so part index 0 never occurs.
% Instead, use the task count and shift the task ID by the array's minimum ID.
% The first task is then part 0, whether the array is declared 0-N or 1-N.
% NOTE(review): assumes a contiguous array with step 1, and that Laboratory
% expects 0 <= part_idx < part_count. Confirm against Laboratory.
%
% Outside a SLURM array job these variables are unset. str2double of the empty
% result gives NaN, so everything runs as a single part: count 1, index 0.
part_count = str2double(getenv("SLURM_ARRAY_TASK_COUNT"));
if isnan(part_count); part_count = 1; end
task_min = str2double(getenv("SLURM_ARRAY_TASK_MIN"));
if isnan(task_min); task_min = 0; end
part_idx = str2double(getenv("SLURM_ARRAY_TASK_ID")) - task_min;
if isnan(part_idx); part_idx = 0; end


%% Configure
% Laboratory
% Laboratory-wide settings, applied on top of the full LaboratoryConfig in the
% "Run" section. `part_count`/`part_idx` come from the SLURM environment parsed
% at the top of the script (1 and 0 when not running as a SLURM array job).
lab_conf = ...
    LaboratoryConfig.partial( ...
        seed = 17, ... % fixed seed; the results in the paper depend on it
        load_enabled = true, ... % load results from `cache/`; set to false to re-run all experiments
        ...
        parallel = true, ... % needs the Parallel Computing Toolbox; set to false without it
        repeat_count = 100, ... % presumably repetitions per experiment -- TODO confirm
        ...
        plot_show = true, ...
        plot_save = true, ...
        ...
        part_count = part_count, ... % number of parts the work is split into
        part_idx = part_idx ... % which part this process handles
    );

% Experiments
% Settings shared by all experiments. Each partial configuration produced by
% `Config.combinations` below is later applied on top of this base.
exp_base_conf = ExperimentConfig( ...
    node_count = 50, ...
    dataset_name = "number", ...
    fl_convergence_threshold = 1, ...
    fl_rounds = -1, ...
    metrics_training_loss = false);
% Two axes of partial configurations:
%   1. Erdos-Renyi networks with edge parameter p in {0.1, 0.5, 0.9};
%   2. minimum girth constraint from 0 to 20.
% Presumably `Config.combinations` forms their cross product (3 x 21
% experiments) -- TODO confirm.
exp_confs = Config.combinations( ...
    arrayfun(@(p) ExperimentConfig.partial(network_layout = "erdos-renyi", ...
                                           network_erdos_renyi_p = p), ...
             [0.1, 0.5, 0.9], UniformOutput = false), ...
    arrayfun(@(min_girth) ExperimentConfig.partial(network_min_girth = min_girth), ...
             0:20, UniformOutput = false));

% Plots
% Base configuration shared by every plot. Each experiment contributes one data
% point:
%   x = its minimum girth, mapped through `girth_to_y`;
%   y = `height(it.metrics.history)`, presumably one history row per global
%       round -- TODO confirm.
% Lines are labelled with `conf_to_label`.
plot_base_conf = ...
    PlotConfig( ...
        x_label = "Stretched Minimal Girth", ...
        y_label = "Global Rounds Until Convergence", ...
        line_to_label = @(it) conf_to_label(it.conf), ... % e.g. "p=0.5" for Erdos-Renyi
        exp_to_data = @(it) [girth_to_y(it.conf.network_min_girth), height(it.metrics.history)] ...
    );
% One plot per reference ExperimentConfig in the array passed to `arrayfun`.
% Currently there is a single Erdos-Renyi reference, so only one plot is made.
% Each plot:
%   * keeps only experiments whose layout and `network_watts_strogatz_k` equal
%     the reference's;
%   * assigns experiments to lines by hashing their config with
%     `network_min_girth` overwritten to 1. Presumably, experiments that differ
%     only in girth therefore share a line.
plot_confs = ...
    arrayfun(@(conf) ...
                 PlotConfig.partial( ...
                     title = layout_to_string(conf), ...
                     filter_exp = @(it) it.conf.network_layout == conf.network_layout && ...
                                            it.conf.network_watts_strogatz_k == conf.network_watts_strogatz_k, ...
                     exp_to_line = @(it) hash(it.conf.set(network_min_girth = 1)) ... % girth normalised so it does not split lines
                 ), ...
             [ ...
                 ExperimentConfig(network_layout = "erdos-renyi") ...
             ], ...
             UniformOutput = false);


%% Run
% Complete each partial configuration by applying it on top of the matching base
% configuration. `UniformOutput = false` keeps the results as cell arrays of
% config objects, with the same shape as the input cells.
exp_confs = cellfun(@(partial_conf) exp_base_conf.set(partial_conf{:}), ...
                    exp_confs, UniformOutput = false);
plot_confs = cellfun(@(partial_conf) plot_base_conf.set(partial_conf{:}), ...
                     plot_confs, UniformOutput = false);

% Assemble the laboratory from the complete configs, with the laboratory-wide
% settings in `lab_conf` applied last. Then run it and draw the plots.
lab = Laboratory(LaboratoryConfig(exp_confs = exp_confs, plot_confs = plot_confs).set(lab_conf{:}));
lab.run();
lab.plot();


%% Helpers
function label = layout_to_string(conf)
    % Returns a display name for `conf.network_layout`, used as a plot title.
    %
    % For "watts-strogatz", the name also shows `conf.network_watts_strogatz_k`
    % with no decimals. Any layout other than "erdos-renyi", "watts-strogatz"
    % or "barabasi-albert" raises an error.
    switch conf.network_layout
        case "erdos-renyi"
            label = "erdos-renyi";
        case "watts-strogatz"
            label = sprintf("watts-strogatz (k=%.0f)", conf.network_watts_strogatz_k);
        case "barabasi-albert"
            label = "barabasi-albert";
        otherwise
            error("Unknown network layout '%s'.", conf.network_layout);
    end
end

function label = conf_to_label(conf)
    % Builds the legend label for one plotted line from its experiment config.
    %
    % The label shows the layout-specific network parameter:
    %   * "erdos-renyi"     -> "p=<network_erdos_renyi_p>" (one decimal);
    %   * "watts-strogatz"  -> also names the layout, with p to one decimal;
    %   * "barabasi-albert" -> also names the layout, with m and no decimals.
    % Any other layout raises an error.
    switch conf.network_layout
        case "erdos-renyi"
            label = sprintf("p=%.1f", conf.network_erdos_renyi_p);
        case "watts-strogatz"
            label = sprintf("watts-strogatz (p=%.1f)", conf.network_watts_strogatz_p);
        case "barabasi-albert"
            label = sprintf("barabasi-albert (m=%.0f)", conf.network_barabasi_albert_m);
        otherwise
            error("Unknown network layout '%s'.", conf.network_layout);
    end
end

function y = girth_to_y(girth, unconstrained_y)
    % Maps minimum-girth setting(s) to the coordinate used when plotting them.
    %
    % The value -1 is treated specially and placed at `unconstrained_y`. Every
    % other girth is plotted at its own value. The mapping is element-wise, so
    % arrays are supported; the original scalar-only `if girth == -1` returned
    % a scalar 50 for an all-(-1) array.
    %
    % Arguments:
    %   girth           - scalar or array of minimum-girth values.
    %   unconstrained_y - position used for girth -1. Defaults to 50, presumably
    %                     chosen to match `node_count = 50` -- verify.
    %
    % Returns:
    %   y - array of the same size as `girth`.
    if nargin < 2
        unconstrained_y = 50;
    end

    y = girth;
    y(girth == -1) = unconstrained_y;
end
