classdef (Abstract) Model < handle
    % MODEL  A learnable model.
    %
    % Concrete subclasses implement TRAIN, COMBINE_WITH_NEIGHBORS,
    % EVALUATE_ACCURACY and GET_FLAT_LEARNABLES. The static helper
    % CALC_CONVERGENCE_DISTANCE only uses GET_FLAT_LEARNABLES.

    methods (Abstract)
        loss = train(obj, dataset_train, epochs, mini_batch_size, initial_learn_rate, decay, momentum)
        % TRAIN  Trains this model for [epochs] on [dataset_train], using
        % [mini_batch_size], [initial_learn_rate], [decay], and [momentum], and
        % returning the training [loss].

        combine_with_neighbors(obj, self_weight, neighbor_models)
        % COMBINE_WITH_NEIGHBORS  Adjusts this model to be the weighted average of
        % itself (with weight [self_weight]) and each model in [neighbor_models] (each
        % with a weight of 1).

        accuracy = evaluate_accuracy(obj, dataset_test, mini_batch_size)
        % EVALUATE_ACCURACY  Returns the [accuracy] of this model's predictions over
        % [dataset_test] in mini-batches of size [mini_batch_size].

        flat_learnables = get_flat_learnables(obj)
        % GET_FLAT_LEARNABLES  Returns all learnable parameters of this model as a
        % column vector.
    end

    methods (Static)
        function distance = calc_convergence_distance(models)
            % CALC_CONVERGENCE_DISTANCE  Returns how far apart [models] are in parameter space.
            %
            % For each learnable parameter (each entry of the vectors returned by
            % GET_FLAT_LEARNABLES), its spread is the largest value of that parameter
            % across [models] minus the smallest. [distance] is the largest spread over
            % all parameters. For real-valued parameters this equals the largest
            % absolute difference between corresponding parameters of any two models
            % (the largest pairwise L-infinity distance).
            %
            % A single model, or models without any learnable parameters, yield 0.
            % All models are expected to return the same number of learnables.

            arguments% (Input)
                models (:, 1) cell {mustBeNonempty};  % cell<Model>
            end
            % arguments (Output)
            %     distance (1, 1) double;
            % end

            if numel(models) == 1
                distance = 0;
                return;
            end

            % One row per model, one column per learnable parameter. Use the plain
            % transpose (.') so that values are reshaped, never conjugated.
            rows = cellfun(@(it) it.get_flat_learnables().', models, UniformOutput = false);
            all_learnables = vertcat(rows{:});

            if isempty(all_learnables)
                % No parameters to compare: models are trivially identical. Without this
                % guard, max() over an empty row would return [] instead of a scalar.
                distance = 0;
                return;
            end

            distance = max(max(all_learnables, [], 1) - min(all_learnables, [], 1));
        end
    end
end
