Meta-Consolidation for Continual Learning (MERLIN)

A reproduction of the paper 'Meta-Consolidation for Continual Learning' by K J Joseph and Vineeth N Balasubramanian, accepted in the proceedings of Neural Information Processing Systems (NeurIPS 2020).
Shambhavi Mishra

Reproducibility Summary

In this report, we attempt to reproduce the paper 'Meta-Consolidation for Continual Learning' by K J Joseph et al., accepted in the proceedings of Neural Information Processing Systems (NeurIPS 2020). The report covers each aspect of reproducing the results and claims put forth in the paper, which proposes a novel continual learning methodology called MERLIN: Meta-Consolidation for Continual Learning. The reproducibility task was feasible: the computation was inexpensive, and we were able to match the results reported in the paper.

Scope of Reproducibility

The paper proposes a novel methodology for continual learning called MERLIN: Meta-Consolidation for Continual Learning. The authors assume that the weights ψ of a neural network for solving task t come from a meta-distribution p(ψ|t). This meta-distribution is learned and consolidated incrementally. The authors operate in the challenging online continual learning setting, where a data point is seen by the model only once.

Methodology

We used the code provided by the authors in their GitHub repository. The parameters used for the models are given explicitly in a config file, which made the paper straightforward to reproduce. A run on the Split MNIST dataset took 2-3 hours on a single NVIDIA GTX 1060 GPU and around 40 minutes on an NVIDIA Tesla P100. Further details are presented in Run set 4. We used Weights & Biases as our default logging and tracking tool to store all results in a single dashboard.

Results

We reproduced the results for a single dataset, Split MNIST, as code was available only for that dataset. The results we obtained are in line with those reported in the paper.

What was easy

The paper was understandable and it was quite fascinating to follow its structure. Along with the theoretical exposition, the mathematical equations made the method easy to re-derive. In addition, the authors provided the original implementation along with the hyperparameters used, making it easy to run with very few modifications.

What was difficult

The only bottleneck we faced was the lack of compute resources, which prevented us from experimenting with all the variants of Split CIFAR-10, Split CIFAR-100 and Split Mini-ImageNet specified in the paper. Code for the baselines would also have been useful for the comparative analysis illustrated in the paper.

Communication with original authors

The authors were very responsive over email and helped us with every doubt we had during the reimplementation. They encouraged and supported the reproducibility effort.

Introduction

This reproducibility submission is an effort to validate the research paper 'Meta-Consolidation for Continual Learning' by K J Joseph et al., accepted in the proceedings of Neural Information Processing Systems (NeurIPS 2020).
Continual learning is a machine learning scenario in which a learning model must adapt to new tasks progressively while maintaining its performance on previously acquired tasks.
The authors propose MERLIN: Meta-Consolidation for Continual Learning, a new continual learning technique based on consolidation in a meta-space, namely the latent space which generates model weights for solving downstream tasks.
In this reproducibility report, we study MERLIN in detail: we run experiments with the authors' open-source code, report important details about issues encountered during reproduction, and compare the obtained results with those reported in the original paper. We report our numbers on seen test accuracy, validation accuracy, loss and average accuracy for each of the given 'k' tasks in the table and plots below.

Scope of Reproducibility

The authors claim that the weights \psi of a neural network for solving task t come from a meta-distribution p(\psi|t), where t is a representation of the task. They propose 'Meta-Consolidation', a methodology to learn this distribution and to continually adapt it to new tasks by consolidating this meta-space of model parameters whenever a new task arrives.
Major contributions listed in the paper are:
  1. A new perspective on continual learning based on the meta-distribution of model parameters and its consolidation.
  2. A method to learn this meta-distribution using a VAE with task-specific priors, allowing an ensemble of models for each task at inference (a paraphrase of the objective is sketched after this list).
  3. MERLIN outperforms well-known baselines and state-of-the-art methods on five continual learning datasets.
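Since the meta-distribution is learned with a VAE whose latent prior is task-specific, the training objective can be written, in our paraphrase (encoder q_{\phi}, decoder p_{\theta}, latent variable z, task-specific prior p(z|t)), as a task-conditioned evidence lower bound:

\log p(\psi \mid t) \;\geq\; \mathbb{E}_{q_{\phi}(z \mid \psi, t)}\!\left[\log p_{\theta}(\psi \mid z, t)\right] \;-\; D_{\mathrm{KL}}\!\left(q_{\phi}(z \mid \psi, t) \,\|\, p(z \mid t)\right)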

Methodology

Understanding the Algorithms

MERLIN tackles the problem in three steps, described below: training a set of task-specific base models, consolidating their weights in a meta-space, and sampling models from that meta-space at inference.

MERLIN: Overall Methodology

We consider a sequence of tasks T_1, T_2, \ldots, T_{k-1} that the learner has seen so far. A new task T_k is introduced at time instance k. In this step, a set of B base models is trained on random subsets of T_k^{tr} (the training split of T_k) to obtain a collection of models \Psi_k = \{\psi_k^1, \ldots, \psi_k^B\}, where each \psi_k^j is a full set of classifier weights. This collection is then used to learn a task-specific distribution over parameters with a VAE-like technique.
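A minimal sketch of this first step is given below. The helper names (make_model, train_base_models) and the training settings are our own illustration, not the authors' API; the only assumptions are a PyTorch dataset holding T_k^{tr} and a factory that builds a fresh classifier.

import copy
import random

import torch
from torch.utils.data import DataLoader, Subset

def train_base_models(task_train_data, make_model, num_base_models=5,
                      subset_fraction=0.6, lr=0.1, device='cpu'):
    # Train B base models on random subsets of the current task's training
    # data and return their weights (state_dicts).
    collected_weights = []
    n = len(task_train_data)
    for _ in range(num_base_models):
        # Draw a random subset of the task's training data.
        indices = random.sample(range(n), int(subset_fraction * n))
        loader = DataLoader(Subset(task_train_data, indices),
                            batch_size=10, shuffle=True)

        model = make_model().to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        criterion = torch.nn.CrossEntropyLoss()

        # A single pass over the subset, in line with the online setting
        # where each data point is seen only once.
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        collected_weights.append(copy.deepcopy(model.state_dict()))
    return collected_weights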

Meta-Consolidation in MERLIN

In the meta-consolidation phase, model parameters are sampled from the decoder of the VAE for all tasks seen so far, each conditioned on a task-specific prior, and these samples are used to refine the overall VAE.
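A rough sketch of this consolidation loop is shown below. It assumes a conditional VAE over flattened weight vectors, weight_vae, exposing sample_prior(task, n), decode(z, task) and a forward pass returning (reconstruction, mu, logvar); these names, and the simplified loss, are our own illustration rather than the authors' exact code.

import torch
import torch.nn.functional as F

def consolidate(weight_vae, vae_optimizer, seen_tasks, samples_per_task=10, epochs=1):
    # Pseudo-rehearsal in the weight meta-space: for every task seen so far,
    # sample weight vectors from the decoder (conditioned on that task's
    # prior) and use them as training data to refine the whole VAE.
    for _ in range(epochs):
        for task in seen_tasks:
            z = weight_vae.sample_prior(task, n=samples_per_task)
            with torch.no_grad():
                pseudo_weights = weight_vae.decode(z, task)

            # Refine the VAE on the generated weight vectors: reconstruction
            # loss plus a KL term (written here against a unit Gaussian for
            # brevity; the paper conditions the prior on the task).
            recon, mu, logvar = weight_vae(pseudo_weights, task)
            recon_loss = F.mse_loss(recon, pseudo_weights)
            kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon_loss + kl

            vae_optimizer.zero_grad()
            loss.backward()
            vae_optimizer.step()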

MERLIN Inference

Inference involves sampling models from the learned parameter distribution for each task and evaluating them on test data.
Since the distribution over model parameters has been learned for every task encountered so far, any number of models can be sampled from it at inference/test time. This enables the proposed technique to ensemble several models at test time.
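A simple way to picture this ensembling, reusing the assumed weight_vae interface from the consolidation sketch above, is to sample a few parameter vectors for the task, load each into the classifier, and average the predicted probabilities. The authors' pipeline also involves a fine-tuning stage (note the finetune_* entries in the configuration shown later), which is omitted from this illustration.

import torch

def ensemble_predict(weight_vae, base_model, task, x, n_models=5):
    # Sample several parameter vectors for `task` from the learned
    # meta-distribution, load each into the classifier, and average
    # the softmax outputs over the ensemble.
    probs = []
    with torch.no_grad():
        z = weight_vae.sample_prior(task, n=n_models)
        weight_vectors = weight_vae.decode(z, task)
        for w in weight_vectors:
            torch.nn.utils.vector_to_parameters(w, base_model.parameters())
            probs.append(torch.softmax(base_model(x), dim=-1))
    return torch.stack(probs).mean(dim=0)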

Model Architecture

The paper proposes different base-classifier architectures for the different datasets.
These base models (their weights, to be specific) are then used to train the Variational AutoEncoder in MERLIN.
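Concretely, turning a trained base model into a VAE training example amounts to flattening its parameters into a single vector; a minimal sketch using PyTorch utilities (our illustration, not the authors' code):

import torch

def weights_to_vectors(state_dicts, template_model):
    # Convert each collected state_dict into a flat parameter vector,
    # producing a (B, weight_dim) tensor of VAE training examples.
    vectors = []
    for sd in state_dicts:
        template_model.load_state_dict(sd)
        vectors.append(
            torch.nn.utils.parameters_to_vector(template_model.parameters()).detach()
        )
    return torch.stack(vectors)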

Datasets

The paper evaluates on five benchmark datasets: Split MNIST, Permuted MNIST, Split CIFAR-10, Split CIFAR-100 and Split Mini-ImageNet.
In the Split datasets, the label space grows with each task, whereas in Permuted MNIST the input space changes with each task while the label space stays fixed. The former is known as the Class-Incremental setting, the latter as the Domain-Incremental setting. The authors claim that MERLIN works in both settings.
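The two settings can be pictured with a short sketch of how tasks are typically constructed for Split MNIST and Permuted MNIST (our illustration; the authors' data loaders may differ in details such as class ordering or permutation seeds):

import torch
from torchvision import datasets, transforms

def split_mnist_task(task_id, classes_per_task=2, root='./data'):
    # Class-incremental: task k contains only the digit classes
    # {k * classes_per_task, ..., (k + 1) * classes_per_task - 1}.
    data = datasets.MNIST(root, train=True, download=True,
                          transform=transforms.ToTensor())
    task_classes = set(range(task_id * classes_per_task,
                             (task_id + 1) * classes_per_task))
    indices = [i for i, y in enumerate(data.targets.tolist()) if y in task_classes]
    return torch.utils.data.Subset(data, indices)

def permuted_mnist_task(task_id, root='./data'):
    # Domain-incremental: every task sees all 10 digits, but the pixels are
    # shuffled by a fixed, task-specific permutation.
    g = torch.Generator().manual_seed(task_id)
    perm = torch.randperm(28 * 28, generator=g)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)[perm].view(1, 28, 28)),
    ])
    return datasets.MNIST(root, train=True, download=True, transform=transform)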
We perform reproducibility on the Split MNIST dataset as the code for the same is publicly available.

Hyper-parameters

Throughout all our experiments, we used the same hyper-parameters as in the paper. The authors clearly state all hyper-parameters to train the models in the experiments.

Computational Requirements

To conduct the experiment, we used an NVIDIA Tesla P100 GPU and an NVIDIA GeForce GTX 1060 GPU.

Results

We used the default configuration values provided by the authors in a .yml file, reproduced below:
task: 'split_mnist'
n_tasks: 5
samples_per_task: 1000
validation_samples_per_task: 100
method:
  run_merlin: True
epochs: 1
n_finetune_epochs: 40
learning_rate: 0.1
batch_size_train: 10
batch_size_test: 128
finetune_learning_rate: 0.001
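For reference, a minimal way to load such a .yml file into a config object (the file name and attribute-style access are our assumptions, not necessarily how the authors' code handles it):

import yaml
from types import SimpleNamespace

with open('split_mnist.yml') as f:
    cfg = SimpleNamespace(**yaml.safe_load(f))

print(cfg.n_tasks, cfg.learning_rate)  # 5 0.1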
We conducted three runs of the experiment for robust testing, as detailed in the table Run set 4.
The plots illustrated below depict the accuracy, test accuracy, loss and average accuracy for each of the 5 tasks for which we reproduced the experiments.
1. Test accuracy, or 'the average accuracy for each task in tasks' (5 tasks in this case):
We observed this value to be 82.4 ± 0.7, while the paper reports 86.6 ± 1.4. The code for computing the test accuracy is attached below for clarity.
def test(model, tasks, verbose=False, mode='Test'):
    # Evaluate the model on every task seen so far.
    accuracies = []
    for task in tasks:
        test_data = MNIST('./data', task=task, mode=mode, transform=transforms.ToTensor())
        test_dataloader = DataLoader(test_data, batch_size=cfg.batch_size_test,
                                     shuffle=cfg.continual.shuffle_datapoints)
        accuracy, _ = evaluate_accuracy(model, test_dataloader, task=task)
        accuracies.append(accuracy)
        if verbose:
            log('Accuracy of task %d is %f' % (task, accuracy))
    # Average accuracy across all tasks, logged to Weights & Biases.
    acc = statistics.mean(accuracies)
    if verbose:
        log('Average accuracy for ' + str(tasks) + ' is ' + str(acc))
    wandb.log({'average accuracy for each task in tasks is': acc})
    return acc
2. The number of parameters in the classifier for the Split MNIST dataset exactly matches the figure given in the paper: 89,610.
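This figure is easy to verify: 89,610 is exactly the parameter count of a fully connected 784-100-100-10 network, which is one architecture consistent with the paper's number (our reconstruction; the exact classifier may differ in layer types or activations):

import torch.nn as nn

classifier = nn.Sequential(
    nn.Linear(28 * 28, 100), nn.ReLU(),
    nn.Linear(100, 100), nn.ReLU(),
    nn.Linear(100, 10),
)

# 78,500 + 10,100 + 1,010 = 89,610 parameters (weights + biases)
print(sum(p.numel() for p in classifier.parameters()))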

Conclusion

From our attempt at reproducibility, we conclude that MERLIN indeed delivers on the aspects highlighted in the paper. We were able to replicate the main results, which were easy to reproduce. The paper was an interesting read and, despite the mathematical complexity, could be understood easily because it is well constructed. We would also encourage the original authors to make their official code public for the other datasets as well.