%{
Runs and visualises experiments to determine feasibility of the reconstruction attack. Experiments can be configured
using the settings below.

The settings below correspond exactly to the code that was executed to create the results in the paper. The code is
fully deterministic and gives the exact same results as in the paper. Install the required add-ons listed below, and
then run the script. The script may take several hours to complete, depending on available hardware. On a
high-performance cluster with 32 cores, the script took 2.7 hours to complete.

Alternatively, to analyse and plot the results generated on the cluster, set `enable_save_workspace` to `false` and set
`enable_load_workspace` to `true`, and add the file `FeasibilitySim.mat` to the script's working directory. Note that
doing so effectively ignores all other settings, because they are overwritten by the values stored in that file.

If you are running this in a graphical environment, it is recommended to set `enable_progress_bar` and
`enable_figures_display` to `true`. (Note that setting `enable_load_workspace` to `true` overrides these settings.)


%% Requirements
%%% Hard requirements
* MATLAB 2023a.
* "Fast Reduced Row Echelon Form", v1.3.0.0, by Armin Ataei.
  Provides the `frref` function, which is significantly faster than `rref`.
* "Customizable Heat Maps", v1.5.0.1, by Ameya Deoras.
  Creates heatmap figures. Ensure that the heatmap plugin is at the top of
  MATLAB's path, or that the plugin has been downloaded into the current working
  directory. Otherwise, MATLAB's built-in `heatmap` function, which is incompatible, is used.

%%% Recommended
* MATLAB Parallel Computing Toolbox.
  Significantly speeds up computation. You must set `enable_parallel = false` if
  you do not have this toolbox.
* MATLAB Instrument Control Toolbox.
  See below.
* "PARFOR progress monitor (progress bar) v4", v2.0.4, by Frerk Saxen.
  Required only if `enable_progress_bar` and `enable_parallel` are both `true`.
  Requires both "MATLAB Parallel Computing Toolbox" and "MATLAB Instrument Control
  Toolbox".
%}


%#ok<*UNRCH> Ignore unreachable code due to configuration
clear;


%% Settings
% General
... `true` if a progress bar should be displayed while running experiments.
enable_progress_bar = false;
... `true` if experiments should be run in parallel (requires the MATLAB Parallel Computing Toolbox).
enable_parallel = true;

% Workspace
... `true` if the workspace should be stored after running all experiments
... Mutually exclusive with `enable_load_workspace`
enable_save_workspace = false;
... `true` if an old workspace should be loaded instead of running new experiments
... Mutually exclusive with `enable_save_workspace`
enable_load_workspace = false;
... The filename to save/load the workspace to/from if enabled
save_workspace_filename = "FeasibilitySim.mat";

% Graph properties
... `true` to include graphs with more than one connected component
include_forests = true;
... `true` to include graphs with colluders that have no edges
include_unused_colluders = true;
... `true` to include graphs with neighbours that have no edges
include_unused_neighbours = false;
... `true` to include graphs in which two adversaries differ in exactly one neighbour
include_trivial_attacks = true;

... Number(s) of colluders to run experiments for
colluders = 3:2:7;
... Number(s) of honest-but-curious neighbours to run experiments for
neighbours = 3:15;
... Number of graphs to try for each graph configuration
... This number excludes graphs that do not satisfy other settings' requirements
graphs_per_config = 1000;
... Maximum number of random candidates to generate for each of those graphs before giving up on it
... (a graph for which no valid candidate is found is recorded as `nan`)
max_graph_attempts = 1000;

% Adversarial knowledge
... `true` if experiment should measure how many summations adversaries need
simulate_adversarial_knowledge = true;

... Number of different run orders to try out in each solvable graph
collusions_per_graph = 100;
... Maximum number of adversarial summations to simulate before giving up
rounds_max = 250;
... Number of rounds to execute between checking for partial solutions
rounds_step = 10;

% Visualisation
... `true` if figures should be displayed afterwards
... Tip: Close all figures by running `close all`
enable_figures_display = false;
... Vector of image formats to save figures as, or an empty list not to save images
... See documentation of `saveas` for list of supported formats
figures_save_formats = "epsc";
... Relative path to store created figures under, if `figures_save_formats` is not empty
figures_directory = "./fig/";

... `true` if visualisation should display `nan` for 0 values
display_0_as_nan = false;


%% Validate settings
% `error` aborts the script, so no `return` is needed after it.
if enable_load_workspace && enable_save_workspace
    error("Settings `enable_load_workspace` and `enable_save_workspace` are mutually exclusive.");
end
if enable_progress_bar && enable_parallel && exist("ParforProgressbar", "class") == 0
    error("Settings `enable_progress_bar` and `enable_parallel` are `true` but `ParforProgressbar` is not installed.");
end
% `frref` is a hard requirement for running experiments; fail early instead of inside the (long) experiment loop
if ~enable_load_workspace && exist("frref", "file") == 0
    error("Required add-on `frref` (Fast Reduced Row Echelon Form) is not installed.");
end

%% Experiments
if enable_load_workspace
    disp("# Loading stored workspace");
    % Overwrites every variable stored in the file, including the settings above
    load(save_workspace_filename);
else
    %% Initialise parallelism
    if enable_parallel
        disp("# Initialising parallellism");
        feature("numcores");  % NOTE(review): presumably called only to print the detected core counts — confirm

        poolobj = gcp("nocreate");
        if isempty(poolobj)
            tic;
            poolobj = parpool("local", feature("numcores"));
            time = toc;
            disp("Created parallel pool with " + poolobj.NumWorkers + " workers in " + time + " seconds.");
        else
            disp("Using existing pool with " + poolobj.NumWorkers + " workers.");
        end
    end


    %% Run experiments
    % An experiment is one choice of parameters k, n, e. Each experiment consists of graphs_per_config random graphs
    % with those parameters. Over each graph, collusions_per_graph random choices of colluders are simulated.

    % Set up outputs
    % The (k, n, e) grid is linearised because `parfor` needs a single linear index; it is de-linearised with
    % `ind2sub` inside the loop and with `reshape` during post-processing.
    experiment_size = [max(colluders), max(neighbours), max(colluders) * max(neighbours)];
    experiment_count = prod(experiment_size);

    % Metrics per experiment
    ... For each graph in each experiment, ratio of private data that can be reconstructed
    data_reconstructed_ratio = nan(experiment_count, graphs_per_config);
    ... For each collusion for each "solvable" graph for each experiment, number of rounds before reconstruction
    ... succeeds
    collusions_rounds_count = nan(experiment_count, graphs_per_config, collusions_per_graph);

    % Start experiments
    disp("# Running experiments. This may take a while");
    if enable_progress_bar
        if enable_parallel
            pbar = ParforProgressbar(experiment_count);
            pbar_cleanup = onCleanup(@() delete(pbar));
        else
            pbar = waitbar(0, "0%");
            pbar_cleanup = onCleanup(@() close(pbar));
        end
    end

    tic;
    % With `enable_parallel = false` the maximum number of workers is 0, so the loop runs without pool workers
    parfor (experiment_idx = 1:experiment_count, enable_parallel * feature("numcores"))
        % Seed per experiment, so results do not depend on which worker runs which iteration
        rng(experiment_idx, "twister");

        if enable_progress_bar
            if enable_parallel
                pbar.increment();
            elseif mod(experiment_idx, 100) == 0 || experiment_idx == experiment_count
                % NOTE(review): `1 - idx / count` seems to assume that the serial `parfor` visits iterations in
                % descending order; `parfor` does not guarantee an iteration order — confirm
                progress = 1 - experiment_idx / experiment_count;
                waitbar(progress, pbar, (100 * progress) + "%");
            end
        end

        % `parfor` requires single linear index
        % k = #colluders, n = #neighbours, e = #edges
        [k, n, e] = ind2sub(experiment_size, experiment_idx);
        % Skip grid points that were not requested or that have more edges than a complete bipartite graph
        if ~any(k == colluders) || ~any(n == neighbours) || e > k * n
            continue;
        end

        % Check shortcuts for configurable requirements (configurations that cannot yield any valid graph)
        if ~include_unused_colluders && e < 2 * k; continue; end  % each colluder then needs >= 2 edges
        if ~include_unused_neighbours && e < n; continue; end     % each neighbour then needs >= 1 edge
        if ~include_forests && e < k + n - 1; continue; end       % a connected graph needs >= k + n - 1 edges

        % Create graphs
        for graph_idx = 1:graphs_per_config
            graph_attempt = 0;
            while graph_attempt < max_graph_attempts
                graph_attempt = graph_attempt + 1;

                % Generate random k-by-n biadjacency matrix with `e` 1s in it, and the adjacency matrix of the
                % corresponding bipartite graph (colluders first, then neighbours)
                B = [ones(1, e), zeros(1, k * n - e)];
                perm = randperm(k * n);
                B = reshape(B(perm), k, n);
                A = [zeros(size(B, 1)), B; B', zeros(size(B, 2))];

                % Skip if any colluder has one neighbour
                if any(sum(B, 2) == 1); continue; end

                % Check configurable requirements
                if ~include_unused_colluders && any(sum(B, 2) == 0); continue; end
                if ~include_unused_neighbours && any(sum(B, 1) == 0); continue; end
                if ~include_trivial_attacks
                    % Two colluders differ in exactly one neighbour iff their rows of B have Hamming distance 1.
                    % Only colluder rows are compared: comparing all rows of A would also match pairs of
                    % neighbours and colluder/neighbour pairs. For 0/1 matrices, B * (1 - B)' + (1 - B) * B' is
                    % the pairwise Hamming distance (diagonal 0); no Statistics Toolbox `pdist` is needed.
                    colluder_distances = B * (1 - B)' + (1 - B) * B';
                    if any(colluder_distances(:) == 1); continue; end
                end
                % Connected over all k + n vertices iff `conncomp` finds a single component
                if ~include_forests && max(conncomp(graph(A))) > 1; continue; end

                % Valid graph found; this also ends the attempt loop after the current iteration
                graph_attempt = max_graph_attempts;
                % Count rows of rref(B) whose absolute entries sum to 1, i.e. rows that isolate a single datum
                data_reconstructed = sum(sum(abs(frref(B)), 2) == 1);
                data_reconstructed_ratio(experiment_idx, graph_idx) = data_reconstructed / n;

                % Simulate adversarial knowledge if at least one datum can be reconstructed
                if data_reconstructed > 0 && simulate_adversarial_knowledge
                    for collusion_idx = 1:collusions_per_graph
                        % Row t marks the private values covered by the t-th colluder summation; each neighbour
                        % owns a block of `rounds_max` columns, one per private value it may use
                        knowledge = zeros(rounds_max, n * rounds_max);

                        n_has_updated = true(1, n);  % Whether neighbour has updated since last colluder summation
                        n_indices = zeros(1, n);  % Index of the neighbour's current private value (0 = none used)
                        n_indices_starts = ((1:n) - 1) * rounds_max + 1;  % First column of each neighbour's block
                                                                          % in `knowledge`

                        t = 1;
                        while t <= rounds_max
                            % Pick a uniformly random user; only colluder selections advance `t`
                            active_user = randi([1, k + n]);
                            if active_user <= k
                                % Colluder selected: record which current private values of its neighbours it
                                % sums; neighbours that updated since their last summation use a new value
                                active_neighbours = find(B(active_user, :));

                                n_indices(1, active_neighbours) = ...
                                    n_indices(1, active_neighbours) + n_has_updated(1, active_neighbours);
                                n_has_updated(1, active_neighbours) = false;

                                active_n_indices = n_indices_starts(active_neighbours) + n_indices(active_neighbours);

                                knowledge(t, active_n_indices) = 1;
                                t = t + 1;

                                % Check for partial solution
                                % NOTE(review): `t` was already incremented, so rows 1..t-1 are checked and the
                                % recorded count is t (one more than the summations used); the final row is only
                                % checked if rounds_max + 1 is a multiple of rounds_step. Kept as is so that the
                                % output matches the published results.
                                if mod(t, rounds_step) == 0 || t == rounds_max
                                    % Each neighbour's block up to its current index. This still includes each
                                    % block's never-written first column and possibly all-zero rows; neither
                                    % creates a row with a single non-zero entry in the rref
                                    knowledge_nonzero_col_idxs = ...
                                        cell2mat(arrayfun(@colon, n_indices_starts, n_indices_starts + n_indices, ...
                                                          "UniformOutput", false));
                                    knowledge_no_zero_rows_or_cols = knowledge(1:(t - 1), knowledge_nonzero_col_idxs);

                                    if any(sum(abs(frref(knowledge_no_zero_rows_or_cols)), 2) == 1)
                                        collusions_rounds_count(experiment_idx, graph_idx, collusion_idx) = t;
                                        break;
                                    end
                                end
                            else
                                % Neighbour selected: its next summed contribution will be a new private value
                                n_has_updated(1, active_user - k) = true;
                            end
                        end
                    end
                end
            end
        end
    end
    time = toc;
    disp("Completed in " + time + " seconds.");

    % Clearing the `onCleanup` object runs its cleanup function, so the progress bar is closed now instead of
    % lingering (and being saved with the workspace). No-op when no progress bar was created.
    clear pbar pbar_cleanup;

    if enable_save_workspace
        save(save_workspace_filename, "-v7.3");
    end
end


%% Post-process results
disp("# Post-processing results");
% Restore the (k, n, e) axes that were linearised for `parfor`, then aggregate per configuration. Graphs and
% collusions without a result hold `nan` and are skipped by the "omitnan" aggregations.
data_reconstructed_ratio = reshape(data_reconstructed_ratio, [experiment_size, graphs_per_config]);
data_reconstructed_ratio_mean = mean(data_reconstructed_ratio, 4, "omitnan");    % over graphs
data_reconstructed_ratio_std = std(data_reconstructed_ratio, [], 4, "omitnan");  % over graphs, normalised by N-1

collusions_rounds_count = reshape(collusions_rounds_count, ...
                                  [experiment_size, graphs_per_config, collusions_per_graph]);
collusions_rounds_count_mean = mean(collusions_rounds_count, [4, 5], "omitnan");  % over graphs and collusions


%% Output summary
% For each number of colluders k, average the per-configuration statistics over all (n, e) configurations.
% Note: "reconstructed (std)" is the mean of per-configuration standard deviations (over graphs), whereas
% "rounds (std)" is the standard deviation of the per-configuration mean round counts.
disp("# Summary");
for k = colluders
    fprintf("k=%d: reconstructed (mean): %.4f%%\n", ...
            k, ...
            mean(data_reconstructed_ratio_mean(k, :, :) * 100, "all", "omitnan"));
    fprintf("k=%d: reconstructed (std): %.4fpp\n", ...
            k, ...
            mean(data_reconstructed_ratio_std(k, :, :) * 100, "all", "omitnan"));

    if simulate_adversarial_knowledge
        % These are averages, i.e. generally non-integers; `%d` would print them in exponential notation
        fprintf("k=%d: rounds (mean): %.4f\n", k, mean(collusions_rounds_count_mean(k, :, :), "all", "omitnan"));
        fprintf("k=%d: rounds (std): %.4f\n", k, std(collusions_rounds_count_mean(k, :, :), 0, "all", "omitnan"));
    end
end


%% Visualisation
% Optionally treat exact zeros as missing values, so that they are shown as `nan` in the visualisation
if display_0_as_nan
    data_reconstructed_ratio_mean = standardizeMissing(data_reconstructed_ratio_mean, 0);
    data_reconstructed_ratio_std = standardizeMissing(data_reconstructed_ratio_std, 0);
    collusions_rounds_count_mean = standardizeMissing(collusions_rounds_count_mean, 0);
end

if enable_figures_display || ~isempty(figures_save_formats)
    disp("# Visualising results");

    % Normalise the save formats to a row of strings, so that `for` iterates over individual formats even if the
    % setting is a char vector (e.g. 'epsc') or a column vector; `[]` becomes an empty row (nothing is saved)
    save_formats = reshape(string(figures_save_formats), 1, []);
    % Only create the output directory when figures are actually saved
    if ~isempty(save_formats) && ~isfolder(figures_directory); mkdir(figures_directory); end

    for k = colluders
        % Plot mean percentage of reconstructed data per (neighbours, edges) configuration
        ... Process data; drop rows/columns that contain no results at all
        xlabels = neighbours;
        ylabels = 1:(k * max(neighbours));
        raw = (100 * squeeze(data_reconstructed_ratio_mean(k, xlabels, ylabels)))';
        nonnan_cols = ~all(isnan(raw), 1);
        nonnan_rows = ~all(isnan(raw), 2);
        raw = raw(nonnan_rows, nonnan_cols);
        xlabels = xlabels(nonnan_cols);
        ylabels = ylabels(nonnan_rows);

        ... Plot heatmap; the title is added only after saving, so saved figures are untitled
        figure("Visible", enable_figures_display);
        heatmap(raw, xlabels, ylabels, ...
                '%.0f%%', ...
                "Colorbar", true, ...
                "FontSize", 0, ...
                "NaNColor", [0 0 0]);
        xlabel("Number of direct neighbours");
        ylabel("Number of bipartite edges");
        for save_format = save_formats
            saveas(gcf, fullfile(figures_directory, "data-reconstructed-" + k), save_format);
        end
        title("data-reconstructed-" + k);


        % Plot the frequency distribution of the number of reconstructed data per number of neighbours
        ... Create 2D histogram bins over (number of neighbours, number of reconstructed data) pairs, one per graph
        raw = reshape(data_reconstructed_ratio(k, neighbours, :, :), numel(neighbours), [])';
        ns = repmat(neighbours, size(raw, 1), 1);
        raw = raw .* ns;
        [bins, xlabels, ylabels] = histcounts2(ns, raw);

        ... Post-process bins: normalise each column to percentages and label bins by their centres
        bins = flipud(bins');
        bins = 100 * (bins ./ sum(bins, 1));
        xlabels = movmean(xlabels, 2, "Endpoints", "discard");
        ylabels = movmean(fliplr(ylabels), 2, "Endpoints", "discard");

        ... Plot heatmap
        figure("Visible", enable_figures_display);
        heatmap(bins, xlabels, ylabels, ...
                '%.0f%%', ...
                "Colorbar", true, ...
                "ShowAllTicks", true, ...
                "FontSize", 0);
        xlabel("Number of direct neighbours");
        ylabel("Number of reconstructed data");
        for save_format = save_formats
            saveas(gcf, fullfile(figures_directory, "data-reconstructed-frequency-" + k), save_format);
        end
        title("data-reconstructed-frequency-" + k);


        % Plot mean recorded round counts per (neighbours, edges) configuration
        if simulate_adversarial_knowledge
            ... Process data; drop rows/columns that contain no results at all
            xlabels = neighbours;
            ylabels = 1:(k * max(neighbours));
            raw = (squeeze(collusions_rounds_count_mean(k, xlabels, ylabels)))';
            nonnan_cols = ~all(isnan(raw), 1);
            nonnan_rows = ~all(isnan(raw), 2);
            raw = raw(nonnan_rows, nonnan_cols);
            xlabels = xlabels(nonnan_cols);
            ylabels = ylabels(nonnan_rows);

            ... Plot heatmap using a reversed autumn colormap
            figure("Visible", enable_figures_display);
            colormap(flipud(autumn));
            heatmap(raw, xlabels, ylabels, ...
                    [], ...
                    "Colorbar", true, ...
                    "FontSize", 0, ...
                    "NaNColor", [0 0 0]);
            xlabel("Number of direct neighbours");
            ylabel("Number of bipartite edges");
            for save_format = save_formats
                saveas(gcf, fullfile(figures_directory, "collusion-rounds-" + k), save_format);
            end
            title("collusion-rounds-" + k);
        end
    end
end


%% Done
% Announce that the whole script has finished
fprintf("# Done\n");
