A Brief Introduction to Continual Learning

A review of some of the concepts, challenges, and approaches in continual learning, with examples and a Colab notebook. Made by Shambhavi Mishra using W&B
Shambhavi Mishra


Introduction

Have you ever found it difficult to tell an apple from an orange?
For most of us, the answer is of course "no, absolutely not." Learning what objects are and remembering them is one of the things people are hard-wired to do. We're experts at recalling what we've encountered more than once.
And while state-of-the-art AI models are better than humans at certain tasks (everything from playing Atari games to organizing big logistics chains), they struggle to retain previously acquired knowledge the way we do.
But wait, the blog title reads 'Continual Learning'. How does that come to our rescue?
Continual Learning deals with a continuous stream of data and aims to help a model both keep learning from new data as it arrives and retain what it has already learned.
To understand these challenges in depth, we need to first understand two important concepts: Distributional Shifts (a.k.a. Concept Drift) and Catastrophic Forgetting. We'll start with the former:

Distributional Shift or Concept Drift

First, as a reminder: we are now dealing with a continuously changing incoming stream of data instead of our usual collected or curated dataset.
Whenever we talk about data, we also talk about the distribution it comes from. Take one of the most common probability distributions, the normal distribution: a lot of everyday data, such as the heights of people in a given region, approximately follows a normal distribution.
However, if our data arrives incrementally, we cannot expect its distribution to stay fixed. When the distribution changes, a Continual Learning model has to recognize the change and adapt to it without forgetting what it has already learned. Such changes in the distribution are called distributional shifts or concept drift.
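Furthermore, we can distinguish two kinds of shift. Virtual concept drift is a change in the input distribution P(x), for example when the balance of data among classes changes, while the relationship between inputs and labels stays the same. Real concept drift is a change in that relationship, P(y|x), itself, so that the same input should now receive a different label.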
Everyday examples of drift include seasonality and words acquiring new meanings through slang over time.
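The two kinds of drift can be illustrated with the heights example. This is a minimal pure-Python sketch under made-up assumptions: heights are drawn from a normal distribution, the label is "tall" (1) when the height exceeds a threshold, and a hypothetical model has learned the rule "tall if height > 175 cm". Under virtual drift we shift the height distribution P(x) but keep the labeling rule; under real drift we keep P(x) but move the threshold, so P(y|x) changes. The helper names `sample_batch` and `accuracy` are illustrative only.

```python
import random
import statistics


def sample_batch(n, mean, std, threshold, seed):
    """Draw n (height_cm, label) pairs; label is 1 ("tall") when height > threshold."""
    rng = random.Random(seed)
    heights = [rng.gauss(mean, std) for _ in range(n)]
    return [(h, int(h > threshold)) for h in heights]


def accuracy(threshold, batch):
    """Accuracy of a fixed rule: predict 1 when height > threshold."""
    return sum(int(h > threshold) == y for h, y in batch) / len(batch)


learned_threshold = 175.0  # rule a hypothetical model learned from the first batch

batches = {
    "no drift": sample_batch(2000, mean=170, std=7, threshold=175, seed=0),
    # Virtual drift: P(x) moves (shorter population), labeling rule unchanged.
    "virtual drift": sample_batch(2000, mean=160, std=7, threshold=175, seed=1),
    # Real drift: same P(x), but the concept "tall" now starts at 165 cm.
    "real drift": sample_batch(2000, mean=170, std=7, threshold=165, seed=2),
}

for name, batch in batches.items():
    mean_height = statistics.fmean(h for h, _ in batch)
    acc = accuracy(learned_threshold, batch)
    print(f"{name:13s}  mean height = {mean_height:5.1f}  accuracy = {acc:.2f}")
```

The fixed rule stays perfectly accurate under virtual drift because the true rule did not change, even though the class balance shifts heavily toward "not tall". Real models only approximate the rule in regions they have seen, so virtual drift can still hurt them in practice; real drift, however, breaks even a perfect model of the old concept.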

Catastrophic Forgetting

Catastrophic forgetting, meanwhile, is the tendency of a neural network to forget previously learned concepts when it is trained sequentially on new ones [1]. In other words, a model that once knew something "forgets" it upon learning something new.
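Catastrophic forgetting is easy to reproduce even in a tiny model. The sketch below is our own toy setup, not the experiment from [1]: a two-weight logistic-regression model is trained with plain SGD on task A (the label depends only on the first input feature) and then on task B (the label depends only on the second), without revisiting task-A data. The task definitions, learning rate, and epoch counts are arbitrary assumptions.

```python
import math
import random


def make_task(n, feature, seed):
    """Toy task (hypothetical): 2-D Gaussian inputs, label 1 if x[feature] > 0."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = (rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
        data.append((x, int(x[feature] > 0)))
    return data


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(w, b, x):
    return sigmoid(w[0] * x[0] + w[1] * x[1] + b)


def train(w, b, data, epochs=20, lr=0.1):
    """Plain SGD on the logistic (cross-entropy) loss, one sample at a time."""
    w = list(w)
    for _ in range(epochs):
        for x, y in data:
            err = predict(w, b, x) - y  # derivative of the loss w.r.t. the logit
            w[0] -= lr * err * x[0]
            w[1] -= lr * err * x[1]
            b -= lr * err
    return w, b


def accuracy(w, b, data):
    return sum((predict(w, b, x) > 0.5) == (y == 1) for x, y in data) / len(data)


task_a = make_task(500, feature=0, seed=0)  # label depends only on x[0]
task_b = make_task(500, feature=1, seed=1)  # label depends only on x[1]
test_a = make_task(1000, feature=0, seed=2)

w, b = train([0.0, 0.0], 0.0, task_a)
acc_a_before = accuracy(w, b, test_a)

w, b = train(w, b, task_b)  # sequential training: task-A data is never seen again
acc_a_after = accuracy(w, b, test_a)

print(f"task A accuracy after training on A: {acc_a_before:.2f}")
print(f"task A accuracy after training on B: {acc_a_after:.2f}")
```

Both tasks share the same two weights. For task B the first feature is pure noise, so training on B pushes the first weight toward zero, and task-A accuracy typically falls toward chance.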

Approaches to Continual Learning

There are three main families of approaches to Continual Learning, and we'll cover each in some detail below: replay methods, regularization-based methods, and architecture-based approaches.
Let's start with the first!

Replay Methods

Do you remember taking practice tests before actual exams in school?
The principle of replay-based methods is to revisit previously acquired knowledge, much like reviewing lesson notes or flash cards before a big test.
In replay-based approaches, we store a subset of past samples in a memory buffer and mix them into training on new data. The stored samples can be chosen at random or selected to represent the features of each class. Alternatively, instead of storing real samples, researchers have used generative models to create (that is, generate) examples of past data.
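A minimal sketch of the random-selection option: a fixed-size replay buffer filled by reservoir sampling, a standard way to keep a uniform random sample of an unbounded stream. The class name `ReservoirBuffer`, the capacity of 100, the replay batch size of 8, and the `(input, label)` tuples are hypothetical choices for illustration; the model update itself is left out.

```python
import random


class ReservoirBuffer:
    """Fixed-size replay memory filled by reservoir sampling.

    After n samples have streamed past, each of them is in the buffer with
    probability capacity / n, without knowing the stream length in advance.
    """

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.samples = []
        self.n_seen = 0
        self.rng = random.Random(seed)

    def add(self, sample):
        self.n_seen += 1
        if len(self.samples) < self.capacity:
            self.samples.append(sample)
        else:
            # Keep the new sample with probability capacity / n_seen,
            # overwriting a uniformly chosen slot.
            j = self.rng.randrange(self.n_seen)
            if j < self.capacity:
                self.samples[j] = sample

    def sample(self, k):
        """Draw up to k stored samples to mix into the current training batch."""
        return self.rng.sample(self.samples, min(k, len(self.samples)))


buffer = ReservoirBuffer(capacity=100)
for step in range(10_000):
    new_batch = [(step, step % 10)]  # hypothetical (input, label) pairs from the stream
    replayed = buffer.sample(8)
    training_batch = new_batch + replayed  # a real learner would train on this batch
    for s in new_batch:
        buffer.add(s)

print(len(buffer.samples), buffer.n_seen)
```

Because every sample seen so far has the same chance of being stored, older data stays represented even though the buffer never grows.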
Recently, replay- or rehearsal-based methods have gained popularity in Continual Learning. Let's look at CoPE (Continual Prototype Evolution: Learning Online from Non-Stationary Data Streams) [4], accepted at ICCV 2021.
Figure 1: Architecture of CoPE [4]. The learner continually updates the network f_θ and the prototypes p^y for all y ∈ Y. The PPP-loss encourages inter-class variance (red arrows) and reduces intra-class variance (green arrows).
The paper builds on representation learning: each class is represented by a prototype, a vector in the learned embedding space.
As illustrated in Figure 1, the data stream is processed one batch at a time. CoPE consists of three components:
  1. Continual Evolution of the Prototypes: The authors point out that the chosen representatives can become stale as the representation space keeps changing. To keep them current, they update the prototypes online, after every batch, with a high momentum (a hyperparameter chosen manually, like any other). These updates are also computationally cheap.
  2. PPP (Pseudo-Prototypical Proxy) Loss: This component answers the question "How do we ensure that the prototypes stay relevant to what is arriving in the data stream?" The key idea is to model the relationships between the representations of a concept. For each instance x_i^c of class c in the batch, the authors define an attractor P(c | x_i^c), the probability that x_i^c is predicted as its own class c, and a repellor set of probabilities P_i(c | x_j^k) that instances x_j^k of other classes k ≠ c are predicted as class c. The loss is the negative log-likelihood summed over all instances and normalized by the batch size |B|:
    $$\mathcal{L} = -\frac{1}{|B|}\left[\sum_{i} \log P(c \mid x_i^c) + \sum_{i} \sum_{x_j^k} \log\left(1 - P_i(c \mid x_j^k)\right)\right]$$
  3. Balanced Replay: This brings us back to where we started. After selecting the representatives and keeping them relevant, CoPE gives every class equal importance by dividing the replay memory equally among the classes, so each class stores the same number of samples. Replay batches are then sampled randomly from this buffer.
If you want to explore CoPE further, check out this presentation or the GitHub repository!
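To make the three components more concrete, here is a minimal pure-Python sketch. It assumes prototypes are unit-normalized embedding vectors updated as an exponential moving average with momentum 0.99, which is our reading of the high-momentum update in [4]. `ppp_style_loss` only evaluates the loss formula from item 2 on probabilities you already have; how CoPE computes those probabilities from prototypes and embeddings is described in [4] and not reproduced here. The random within-class eviction in `BalancedMemory` is our own simple choice, not necessarily the paper's. All names are hypothetical.

```python
import math
import random
from collections import defaultdict


def normalize(v):
    """Scale a vector to unit L2 norm (prototypes live on the unit sphere)."""
    norm = math.sqrt(sum(a * a for a in v)) or 1.0
    return [a / norm for a in v]


def update_prototypes(prototypes, embeddings, labels, momentum=0.99):
    """Momentum (EMA) update of class prototypes from one batch of embeddings."""
    by_class = defaultdict(list)
    for z, y in zip(embeddings, labels):
        by_class[y].append(normalize(z))
    for y, zs in by_class.items():
        batch_mean = [sum(col) / len(zs) for col in zip(*zs)]
        if y not in prototypes:  # first time class y appears in the stream
            prototypes[y] = normalize(batch_mean)
        else:
            mixed = [momentum * p + (1 - momentum) * m
                     for p, m in zip(prototypes[y], batch_mean)]
            prototypes[y] = normalize(mixed)
    return prototypes


def ppp_style_loss(attractor_probs, repellor_probs, batch_size):
    """Evaluate the PPP loss formula from precomputed probabilities.

    attractor_probs: P(c | x_i^c) for each instance i.
    repellor_probs: for each i, the list of P_i(c | x_j^k) over its repellors.
    """
    attract = sum(math.log(p) for p in attractor_probs)
    repel = sum(math.log(1.0 - q) for qs in repellor_probs for q in qs)
    return -(attract + repel) / batch_size


class BalancedMemory:
    """Replay memory whose capacity is split equally among the classes seen so far."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.per_class = defaultdict(list)
        self.rng = random.Random(seed)

    def add(self, sample, label):
        self.per_class[label].append(sample)
        quota = max(1, self.capacity // len(self.per_class))
        for items in self.per_class.values():
            while len(items) > quota:  # evict random samples down to the equal share
                items.pop(self.rng.randrange(len(items)))

    def sample(self, k):
        """Uniformly sample up to k stored (sample, label) pairs for replay."""
        pool = [(s, y) for y, items in self.per_class.items() for s in items]
        return self.rng.sample(pool, min(k, len(pool)))


protos = update_prototypes({}, [[1.0, 0.0], [0.9, 0.1]], [0, 0])
protos = update_prototypes(protos, [[0.0, 1.0]], [0])
print(protos[0])  # still close to [1, 0]: the high momentum keeps the prototype stable

memory = BalancedMemory(capacity=20)
for i in range(100):
    memory.add(i, label=i % 4)
print({y: len(items) for y, items in memory.per_class.items()})  # {0: 5, 1: 5, 2: 5, 3: 5}
```

A high momentum means one unusual batch barely moves a prototype, while a sustained change in the stream still pulls it along over many batches.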

Regularization-Based Methods

Does that ring a bell? It should: regularization is the same family of techniques we use to keep models from overfitting their training data.
Simple methods such as dropout and early stopping can also reduce catastrophic forgetting to some extent, because they limit how far the weights drift. But some weights matter more than others for previously learned tasks, so wouldn't it be wise to remember their values and keep them from changing too much?
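The Elastic Weight Consolidation (EWC) approach [5], proposed by Kirkpatrick et al. (2017), uses the diagonal of the Fisher information matrix to estimate how important each weight is for a previous task, and adds a quadratic penalty that pulls important weights back toward their old values.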
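For reference, the EWC objective of [5] for learning a task B after a task A can be written as follows. Here θ*_A are the parameters at the end of task A, F_i is the i-th diagonal entry of the Fisher information evaluated at θ*_A (with the expectation taken over task-A data), and λ sets how strongly old weights are anchored. λ corresponds to the `weight` argument of the implementation used in this post.

$$
\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^{*}_{A,i}\right)^2,
\qquad
F_i = \mathbb{E}\left[\left(\frac{\partial \log p(y \mid x, \theta)}{\partial \theta_i}\right)^2\right]_{\theta = \theta^{*}_{A}}
$$

Weights with a large F_i strongly affect the model's predictions on task A, so moving them is expensive; weights with F_i close to 0 remain free to learn task B.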
To try it, train a basic feed-forward network on MNIST for a few epochs (we used 3 epochs for the results illustrated below), then train the same network on a second dataset (Fashion-MNIST here) using EWC.
We use the EWC implementation from this GitHub repository [8]:
```python
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
from torch.utils.data import DataLoader


class ElasticWeightConsolidation:
    def __init__(self, model, crit, lr=0.001, weight=1000000):
        self.model = model
        self.weight = weight  # lambda: strength of the EWC penalty (may need retuning)
        self.crit = crit
        self.optimizer = optim.Adam(self.model.parameters(), lr)

    def _update_mean_params(self):
        # Store a copy of the current parameters (theta*_A) as buffers on the model.
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name + '_estimated_mean', param.data.clone())

    def _update_fisher_params(self, current_ds, batch_size, num_batch):
        # Diagonal Fisher estimate: squared gradients of the log-likelihood,
        # averaged over mini-batches (a mini-batch approximation).
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        params = list(self.model.parameters())
        fisher = [torch.zeros_like(p) for p in params]
        n_batches = 0
        for i, (input, target) in enumerate(dl):
            if i >= num_batch:
                break
            output = F.log_softmax(self.model(input), dim=1)
            # Log-probability of the true class for each sample in the batch.
            log_likelihood = output.gather(1, target.unsqueeze(1)).mean()
            grads = autograd.grad(log_likelihood, params)
            for f, g in zip(fisher, grads):
                f += g.detach() ** 2
            n_batches += 1
        _buff_param_names = [name.replace('.', '__') for name, _ in self.model.named_parameters()]
        for _buff_param_name, f in zip(_buff_param_names, fisher):
            self.model.register_buffer(_buff_param_name + '_estimated_fisher', f / max(n_batches, 1))

    def register_ewc_params(self, dataset, batch_size, num_batches):
        # Call once after finishing a task, before training on the next one.
        self._update_fisher_params(dataset, batch_size, num_batches)
        self._update_mean_params()

    def _compute_consolidation_loss(self, weight):
        # EWC penalty: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2
        try:
            losses = []
            for param_name, param in self.model.named_parameters():
                _buff_param_name = param_name.replace('.', '__')
                estimated_mean = getattr(self.model, '{}_estimated_mean'.format(_buff_param_name))
                estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(_buff_param_name))
                losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
            return (weight / 2) * sum(losses)
        except AttributeError:
            # No previous task registered yet, so there is no penalty.
            return 0

    def forward_backward_update(self, input, target):
        output = self.model(input)
        loss = self._compute_consolidation_loss(self.weight) + self.crit(output, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
```
Want to explore how to visualize your results? Run this Colab notebook.
To log the results with W&B, add these lines:
```python
!pip install wandb

import wandb

wandb.login()

config = dict(learning_rate=0.0001)
wandb.init(project="ewc", config=config)

# code to train your model

# time to test the performance
# (accu, test_loader and f_test_loader are helpers from the accompanying notebook)
Accuracy_Fashion_MNIST = accu(ewc.model, f_test_loader)
Accuracy_MNIST = accu(ewc.model, test_loader)

# log your important metrics using this simple line of code!
wandb.log({'Accuracy of Fashion_MNIST': Accuracy_Fashion_MNIST, 'Accuracy of MNIST': Accuracy_MNIST})
```
A question to explore on your own: how do regularization-based methods perform under domain shift?

Architecture-Based Approaches

We can also modify the architecture to overcome forgetting. One way is to nest multiple networks, one per task, inside a global network, as illustrated below.
A possible architecture-based solution to forgetting. Source.
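One simple way to picture the nested-networks idea is a global model that routes each input to the sub-network of its task. In this pure-Python sketch, each "sub-network" is only the weight vector of a linear scorer trained with squared error, a hypothetical stand-in for a real neural network, and `TaskRoutedModel` is an illustrative name, not a library class.

```python
class TaskRoutedModel:
    """A 'global' model holding one independent sub-network per task."""

    def __init__(self, n_features):
        self.n_features = n_features
        self.subnets = {}

    def add_task(self, task_id):
        # A new task gets fresh parameters; existing sub-networks are untouched.
        self.subnets[task_id] = [0.0] * self.n_features

    def score(self, x, task_id):
        w = self.subnets[task_id]
        return sum(wi * xi for wi, xi in zip(w, x))

    def sgd_step(self, x, target, task_id, lr=0.1):
        """Squared-error SGD step that only updates the active task's sub-network."""
        w = self.subnets[task_id]
        err = self.score(x, task_id) - target
        for i in range(len(w)):
            w[i] -= lr * err * x[i]


model = TaskRoutedModel(n_features=2)
model.add_task("A")
for _ in range(200):
    model.sgd_step([1.0, 0.0], target=1.0, task_id="A")
before = model.score([1.0, 0.0], "A")

model.add_task("B")
for _ in range(200):
    model.sgd_step([1.0, 0.0], target=-1.0, task_id="B")  # conflicting target, same input
after = model.score([1.0, 0.0], "A")
print(before == after)  # True: training task B never touched task A's weights
```

The price of this isolation is that the model must know the task identity at prediction time, and it grows with every new task.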
Another approach is to create a new model for each new task, connected to all the previously trained models. A hard-attention process that masks important weights to freeze them has also been proposed. Dual-memory approaches have also been defined for CL, in which one model learns new concepts while another serves as a store of all previously acquired experience.
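The masking idea can be sketched with a binary mask over a flat list of weights. This is a simplified illustration, not the hard-attention method itself: it assumes weight magnitude as the importance criterion (as pruning-based methods do), whereas the hard-attention approach learns which units to mask. The function names and toy values are hypothetical.

```python
def freeze_top_weights(weights, frozen, fraction):
    """After a task, freeze the largest-magnitude weights that are still free.

    Magnitude is a crude importance proxy (an assumption for this sketch).
    Returns a new boolean mask; True means 'reserved for an earlier task'.
    """
    free = [i for i, f in enumerate(frozen) if not f]
    k = int(len(free) * fraction)
    top = sorted(free, key=lambda i: abs(weights[i]), reverse=True)[:k]
    mask = list(frozen)
    for i in top:
        mask[i] = True
    return mask


def masked_sgd_step(weights, grads, frozen, lr=0.1):
    """Gradient step that leaves frozen weights untouched (their gradient is masked out)."""
    return [w if f else w - lr * g for w, g, f in zip(weights, grads, frozen)]


weights = [0.9, -0.05, 0.4, 0.01]  # hypothetical weights after task 1
frozen = freeze_top_weights(weights, [False] * 4, fraction=0.5)
print(frozen)  # [True, False, True, False]

grads_task2 = [1.0, 1.0, 1.0, 1.0]  # hypothetical task-2 gradients
weights = masked_sgd_step(weights, grads_task2, frozen)
print(weights)  # the frozen entries 0.9 and 0.4 are unchanged
```

Each task keeps exclusive use of the weights it froze, while later tasks learn with the remaining free capacity, which shrinks as more tasks arrive.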

Conclusion

In this post, we explored how Continual Learning might bring us a step closer to AGI. We began with two basic but important concepts, catastrophic forgetting and concept drift. We then walked through the main research directions in CL (replay, regularization, and architecture-based methods) and looked at representative recent work in each.

References

  1. McCloskey, M. and Cohen, N. J. (1989). Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of Learning and Motivation, volume 24, pages 109–165. Elsevier.
  2. De Lange, M. et al. (2021). A continual learning survey: Defying forgetting in classification tasks. IEEE Transactions on Pattern Analysis and Machine Intelligence.
  3. Rebuffi, S.-A., Kolesnikov, A., Sperl, G., and Lampert, C. H. (2017). iCaRL: Incremental Classifier and Representation Learning. CVPR.
  4. De Lange, M. and Tuytelaars, T. (2021). Continual Prototype Evolution: Learning Online from Non-Stationary Data Streams. ICCV.
  5. Kirkpatrick, J. et al. (2017). Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences.
  6. Normal Distribution.
  7. Lesort, T. (2020). Continual Learning: Tackling Catastrophic Forgetting in Deep Neural Networks with Replay Processes. PhD thesis.
  8. GitHub repository for EWC: https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks