classdef Experiment
    % EXPERIMENT  Runs a fully-decentralised federated learning experiment.
    %
    % An experiment consists of distributing a dataset amongst several nodes and
    % then, in each round, selecting a random node which learns from its
    % neighbours before training its own model a bit. Rounds continue until a
    % stopping condition holds: the configured round limit is reached, the
    % models' convergence distance drops to the configured threshold, or the
    % metrics GUI is stopped.
    %
    % See also: ExperimentConfig, PerformanceSim

    properties
        % IDX  This [Experiment]'s unique identifier, containing both its identity and
        % its repetition number.
        idx (1, 2) {mustBeInteger, mustBePositive} = [1, 1];

        % LOGGER  The [Logger] to log with during execution.
        logger (1, 1) Logger;
        % CONF  The configuration on how to run the [Experiment].
        conf (1, 1) ExperimentConfig;
        % METRICS  The metric gatherer.
        metrics (1, 1) Metrics;

        % RAW_GRAPH  The created graph before stretching.
        raw_graph (1, 1) graph;
        % GRAPH  The created graph in which users are configured.
        graph (1, 1) graph;
        % MODELS  The trained (or to-be-trained) models, one per node.
        models (:, 1) cell;  % cell<Model>
    end


    methods
        function obj = Experiment(idx, logger, conf)
            % EXPERIMENT  Constructs experiment number [idx] using [logger] and [conf].
            %
            % The [Metrics] tracker is created here from [conf] but only started by
            % [run].

            arguments% (Input)
                idx (1, 2) {mustBeInteger, mustBePositive};
                logger (1, 1) Logger;
                conf (1, 1) ExperimentConfig;
            end
            % arguments (Output)
            %     obj (1, 1) Experiment;
            % end

            obj.idx = idx;

            obj.logger = logger;
            obj.conf = conf;
            % Convergence is only tracked when a non-negative threshold is configured.
            obj.metrics = Metrics(obj.conf.metrics_text, ...
                                  obj.conf.metrics_gui, ...
                                  obj.conf.metrics_gui_by_node, ...
                                  obj.conf.node_count, ...
                                  obj.conf.fl_rounds, ...
                                  obj.conf.fl_convergence_threshold >= 0, ...
                                  obj.conf.metrics_training_loss, ...
                                  obj.conf.metrics_validation_accuracy);
        end


        function obj = run(obj)
            % RUN  Runs this experiment.
            %
            % Loads and splits the configured dataset, generates the network graph,
            % creates one model per node, and then runs federated learning rounds
            % until a stopping condition holds. Returns the updated [Experiment]
            % with [metrics], [raw_graph], [graph], and [models] filled in.

            obj.logger.header("Starting experiment #%d-%d.\n", obj.idx(1), obj.idx(2)); total_tic = tic;
            % Per-step progress is only logged when requested; otherwise it is discarded.
            if obj.conf.log_rounds; round_logger = obj.logger; else; round_logger = VoidLogger(); end

            if obj.conf.log_settings
                obj.logger.header("Displaying settings.\n");
                settings_lines = split(formattedDisplayText(obj.conf), newline) + newline;
                % Skips the first line of the display text.
                obj.logger.print_each(settings_lines(2:end));
                clear settings_lines;
                obj.logger.footer("Done displaying settings.\n");
            end

            round_logger.print("Creating metrics tracker...\n"); tic;
            obj.metrics = obj.metrics.start();
            % NOTE(review): unlike every other `footer` call in this class, this one has
            % no matching `header`; confirm this is intended for [Logger]'s layout.
            round_logger.footer("done in %.3f seconds.\n", toc);

            %% Create dataset
            dataset = obj.create_dataset(round_logger);
            [datasets_train, datasets_test] = obj.split_dataset(round_logger, dataset);
            clear dataset;


            %% Create network graph
            obj = obj.create_graph(round_logger);


            %% Run federated averaging
            % Initialize models
            obj.models = obj.create_models(round_logger, datasets_train);

            % Run rounds
            round_idx = 0;
            while true
                if obj.conf.metrics_gui && obj.metrics.monitor.Stop
                    break;
                end
                % A negative round limit means "no limit".
                if obj.conf.fl_rounds >= 0 && round_idx >= obj.conf.fl_rounds
                    break;
                end
                % A negative threshold disables the convergence check.
                % NOTE(review): at round 0 no convergence has been recorded yet; this
                % relies on [Metrics] returning a value that fails the comparison
                % (e.g. NaN) — confirm.
                if obj.conf.fl_convergence_threshold >= 0 && ...
                       obj.metrics.get_record("Convergence", round_idx) <= obj.conf.fl_convergence_threshold
                    break;
                end

                round_idx = round_idx + 1;

                % Select node uniformly at random
                selected_node = randsample(obj.conf.node_count, 1);
                obj.metrics.record_round(round_idx, selected_node);
                round_logger.header("Round %d. Selected node: %d (%dº time). Number of nodes never selected: %d.\n", ...
                                    round_idx, ...
                                    selected_node, ...
                                    obj.metrics.get_rounds_recorded_for(selected_node), ...
                                    numel(obj.metrics.get_unrecorded_nodes()));

                % Average
                % NOTE(review): the results of `combine_with_neighbors` and `train` are not
                % assigned back, so this presumably relies on [Model] being a handle
                % class — confirm.
                round_logger.print("Averaging parameters with neighbors... ");
                obj.models{selected_node}.combine_with_neighbors(obj.conf.fl_self_weight, ...
                                                                 obj.models(neighbors(obj.graph, selected_node)));
                round_logger.append("done.\n");

                % Train
                round_logger.print("Training... ");
                training_loss = obj.models{selected_node}.train(datasets_train{selected_node}, ...
                                                                obj.conf.train_epochs_per_round, ...
                                                                obj.conf.train_mini_batch_size, ...
                                                                obj.conf.train_initial_learn_rate, ...
                                                                obj.conf.train_decay, ...
                                                                obj.conf.train_momentum);
                if obj.conf.metrics_training_loss
                    obj.metrics.record_training_loss(round_idx, training_loss);
                end
                round_logger.append("done.\n");

                % Validation
                round_logger.print("Validating current model... ");
                if obj.conf.fl_convergence_threshold >= 0
                    obj.metrics.record_convergence(round_idx, Model.calc_convergence_distance(obj.models));
                end
                if obj.conf.metrics_validation_accuracy
                    accuracy = obj.models{selected_node}.evaluate_accuracy(datasets_test{selected_node}, ...
                                                                           obj.conf.train_mini_batch_size);
                    obj.metrics.record_validation_accuracy(round_idx, accuracy);
                end
                round_logger.append("done.\n");

                % Output
                obj.metrics.update_display(round_logger, round_idx);

                round_logger.footer("Completed round %d.\n", round_idx);
            end


            %% Post-processings
            obj.metrics.gather();

            % Display outputs
            if obj.conf.metrics_text
                obj.logger.header("Displaying collected metrics.\n");
                obj.logger.print_each([newline; split(formattedDisplayText(obj.metrics.history), newline) + newline]);
                obj.logger.footer("Done displaying collected metrics.\n");
            end

            % Done
            obj.logger.footer("Completed experiment #%d-%d in %.3f seconds.\n", obj.idx(1), obj.idx(2), toc(total_tic));
        end
    end


    methods (Access = private)
        function dataset = create_dataset(obj, round_logger)
            % CREATE_DATASET  Loads the dataset named by [conf.dataset_name] as a table.
            %
            % The returned table has an "inputs" and a "labels" column (EMNIST also
            % has "types", holding the writers). MNIST and EMNIST are downloaded into
            % "./cache" if needed; the pre-processed EMNIST table is cached on disk.
            % "number" generates [conf.node_count] random numbers in [0, 50) with
            % zero labels. Errors on an unknown dataset name.

            if obj.conf.dataset_name == "mnist"
                download_and_cache(round_logger, ...
                                   "MNIST", ...
                                   "mat", ...
                                   "30MB", ...
                                   "https://lucidar.me/en/matlab/files/mnist.mat");

                round_logger.print("Pre-processing MNIST dataset... "); tic;
                mnist = load("./cache/MNIST.mat");
                mnist.training = rmfield(mnist.training, ["count", "width", "height"]);
                mnist.training.images = permute(mnist.training.images, [3, 1, 2]);
                mnist.test = rmfield(mnist.test, ["count", "width", "height"]);
                mnist.test.images = permute(mnist.test.images, [3, 1, 2]);

                dataset = [struct2table(mnist.training); struct2table(mnist.test)];
                % Roundabout assignment because assigning a cell to a matrix is not possible
                dataset.images = reshape(dataset.images, [], 28 * 28);
                % Transposed so that each column holds one image before splitting into cells
                dataset.images = squeeze(num2cell(double(reshape(dataset.images', 28, 28, []) / 256), [1, 2]));
                dataset.labels = categorical(dataset.labels, 0:9);
                dataset = renamevars(dataset, "images", "inputs");

                clear mnist;
                round_logger.append("done in %.3f seconds.\n", toc);
            elseif obj.conf.dataset_name == "emnist"
                download_and_cache(round_logger, ...
                                   "EMNIST", ...
                                   "zip", ...
                                   "709MB", ...
                                   "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/matlab.zip");

                dataset_processed_path = "./cache/EMNIST.processed.mat";
                if exist(dataset_processed_path, "file")
                    round_logger.print("Loading pre-processed EMNIST dataset... "); tic;
                    dataset = load(dataset_processed_path).dataset;
                    round_logger.append("done in %.3f seconds.\n", toc);
                else
                    round_logger.print("Pre-processing EMNIST dataset... "); tic;

                    dataset = load("./cache/EMNIST/matlab/emnist-digits.mat").dataset;
                    dataset = [struct2table(dataset.train); struct2table(dataset.test)];
                    % Transposed because `num2cell` leaves an unsqueezed dimension otherwise
                    dataset.images = squeeze(num2cell(double(reshape(dataset.images', 28, 28, []) / 256), [1, 2]));
                    dataset.labels = categorical(dataset.labels, 0:9);
                    dataset.writers = categorical(dataset.writers);
                    dataset = renamevars(dataset, ["images", "writers"], ["inputs", "types"]);

                    % Cache the processed table so later experiments can skip this step
                    save(dataset_processed_path, "dataset");
                    round_logger.append("done in %.3f seconds.\n", toc);
                end
            elseif obj.conf.dataset_name == "number"
                numbers = 50 * rand([obj.conf.node_count, 1]);
                dataset = table(numbers, zeros([obj.conf.node_count, 1]), VariableNames=["inputs", "labels"]);
            else
                error("Unknown dataset '%s'.", obj.conf.dataset_name);
            end
        end


        function [datasets_train, datasets_test] = split_dataset(obj, round_logger, dataset)
            % SPLIT_DATASET  Distributes [dataset] amongst the nodes.
            %
            % Returns cell arrays indexed by node holding each node's training and
            % test tables. MNIST/EMNIST are split by `dataset_split` according to the
            % configured sample count, split fraction, and (non-)IID settings. For
            % "number", node i trains on row i and every node tests on the whole
            % dataset. Errors on an unknown dataset name.

            round_logger.print("Splitting dataset... "); tic;
            if ismember(obj.conf.dataset_name, ["mnist", "emnist"])
                [datasets_train, datasets_test] = ...
                    dataset_split(dataset, ...
                                  categorical(0:9), ...
                                  obj.conf.node_count, ...
                                  obj.conf.dataset_samples, ...
                                  obj.conf.dataset_split_fraction, ...
                                  obj.conf.dataset_name == "emnist" && obj.conf.dataset_iid_emnist_by_writer, ...
                                  obj.conf.dataset_iid, ...
                                  obj.conf.dataset_iid_dirichlet_alpha);
            elseif obj.conf.dataset_name == "number"
                datasets_train = arrayfun(@(it) {dataset(it, :)}, (1:obj.conf.node_count)');
                datasets_test = cell([obj.conf.node_count, 1]);
                datasets_test(:) = {dataset};
            else
                error("Unknown dataset '%s'.", obj.conf.dataset_name);
            end
            round_logger.append("done in %.3f seconds.\n", toc);
        end


        function obj = create_graph(obj, round_logger)
            % CREATE_GRAPH  Generates [raw_graph] and its stretched version [graph].
            %
            % A graph with the configured layout is generated and stretched to the
            % configured minimum girth. If [conf.network_require_connected] is false,
            % the first graph is accepted. Otherwise graphs are regenerated until the
            % stretched graph is connected. Errors if no graph has been accepted
            % within [conf.network_max_attempts] attempts, or on an unknown layout.

            round_logger.print("Creating network graph... "); tic;
            attempt_idx = 0;
            while true
                attempt_idx = attempt_idx + 1;
                if attempt_idx > obj.conf.network_max_attempts
                    error("Failed to generate network in maximum number of attempts.");
                end

                if obj.conf.network_layout == "erdos-renyi"
                    obj.raw_graph = Graphs.generate_erdos_renyi(obj.conf.node_count, obj.conf.network_erdos_renyi_p);
                elseif obj.conf.network_layout == "watts-strogatz"
                    obj.raw_graph = Graphs.generate_watts_strogatz(obj.conf.node_count, ...
                                                                   obj.conf.network_watts_strogatz_k, ...
                                                                   obj.conf.network_watts_strogatz_p);
                elseif obj.conf.network_layout == "barabasi-albert"
                    obj.raw_graph = Graphs.generate_barabasi_albert(obj.conf.node_count, obj.conf.network_barabasi_albert_m);
                elseif obj.conf.network_layout == "complete"
                    obj.raw_graph = Graphs.generate_complete(obj.conf.node_count);
                elseif obj.conf.network_layout == "empty"
                    obj.raw_graph = Graphs.generate_empty(obj.conf.node_count);
                else
                    error("Unknown network layout '%s'.", obj.conf.network_layout);
                end

                obj.graph = Graphs.stretched(obj.raw_graph, min_girth = obj.conf.network_min_girth);

                % Accept the graph unless connectivity is required and violated.
                % `conncomp` assigns every node to bin 1 exactly when there is a single
                % connected component. (Previously the loop only exited when
                % connectivity was required, so it could never accept a graph
                % otherwise.)
                if ~obj.conf.network_require_connected || max(conncomp(obj.graph)) == 1
                    break;
                end
            end
            round_logger.append("done in %.3f seconds.\n", toc);
        end


        function models = create_models(obj, round_logger, datasets_train)
            % CREATE_MODELS  Creates the initial model of every node.
            %
            % Returns a [conf.node_count]-by-1 cell array. MNIST/EMNIST nodes get a
            % fresh `MnistModel`. "number" nodes get a `NumberModel` constructed from
            % the first input of that node's training data. Errors on an unknown
            % dataset name.

            round_logger.print("Initializing models... "); tic;
            models = cell(obj.conf.node_count, 1);
            for node_idx = 1:obj.conf.node_count
                if ismember(obj.conf.dataset_name, ["mnist", "emnist"])
                    models{node_idx} = MnistModel();
                elseif obj.conf.dataset_name == "number"
                    models{node_idx} = NumberModel(datasets_train{node_idx}.inputs(1));
                else
                    error("Unknown dataset '%s'.", obj.conf.dataset_name);
                end
            end
            round_logger.append("done in %.3f seconds.\n", toc);
        end
    end
end
