Intro to Meta-Learning
Meta-learning is a generalization of machine learning
Rather than solve any problem perfectly, meta-learning seeks to improve the process of learning itself. It's appealing from a cognitive science perspective: humans need way fewer examples than a deep net to understand a pattern, and we can often pick up new skills and habits faster if we're more self-aware and intentional about reaching a certain goal.
Higher accuracy with fewer examples
In regular deep learning, we apply gradient descent over training examples to learn the best parameters for a particular task (like classifying a photo of an animal into one of 5 possible species). In meta-learning, the task itself becomes a training example: we apply a learning algorithm over many tasks to learn the best parameters for a particular problem type (e.g. classification of photos into N classes). In Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks from ICML 2017, the meta-learning algorithm is, elegantly, gradient descent, and it works for any inner model type that is itself trained with gradient descent (hence "model-agnostic"). Finn, Abbeel, and Levine apply this to classification, regression, and reinforcement learning problems and tune a meta-model (the outer model) that can learn quickly (1-10 gradient updates) on a new task with only a few examples (1-5 per class for 2-20 classes). How well does this work in practice and how can we best apply the meta-model to new datasets?
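As a concrete (if toy) illustration of the nested structure, here is a minimal sketch of MAML on synthetic linear-regression tasks, written from the paper's description rather than the released code; a linear model with MSE loss keeps the inner gradient and Hessian in closed form, so the second-order meta-gradient is explicit:

```python
import numpy as np

# Minimal MAML sketch on synthetic linear-regression tasks (a toy illustration
# written from the paper's description, not the released TensorFlow code).
# With a linear model and MSE loss, the inner-loop gradient and Hessian have
# closed forms, so the second-order meta-gradient can be written explicitly.

rng = np.random.default_rng(0)
K, alpha, beta = 5, 0.01, 0.001   # shots per task, inner lr, outer (meta) lr

def sample_task():
    """Each task is y = a*x + b + noise with task-specific (a, b)."""
    a, b = rng.normal(), rng.normal()
    def draw(n):
        x = rng.uniform(-5, 5, size=n)
        X = np.stack([x, np.ones(n)], axis=1)           # features [x, 1]
        y = a * x + b + rng.normal(0, 0.1, size=n)
        return X, y
    return draw

def grad_mse(w, X, y):
    return 2.0 / len(y) * X.T @ (X @ w - y)

def hess_mse(X, y):
    return 2.0 / len(y) * X.T @ X                       # constant for a linear model

w0 = np.zeros(2)                                        # meta-parameters: the shared initialization
for it in range(2000):
    meta_grad = np.zeros_like(w0)
    for _ in range(4):                                  # meta-batch of tasks
        draw = sample_task()
        X_tr, y_tr = draw(K)                            # K-shot support set
        X_val, y_val = draw(K)                          # held-out set for the meta-loss
        w_task = w0 - alpha * grad_mse(w0, X_tr, y_tr)  # one inner gradient step
        # Chain rule through the inner step: d(w_task)/d(w0) = I - alpha * H_train
        jac = np.eye(2) - alpha * hess_mse(X_tr, y_tr)
        meta_grad += jac @ grad_mse(w_task, X_val, y_val)
    w0 -= beta * meta_grad / 4                          # outer gradient-descent step on w0
```

In the paper, this same outer loop wraps a small convolutional network trained with gradient descent on image-classification tasks instead of a toy linear model.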
Few-shot classification on mini-ImageNet (MIN)
In this report, I focus on MAML for few-shot image classification, instrumenting the original code for the paper. Below are some examples from the mini-ImageNet (MIN) dataset with my best guesses as to the labels (which could be more specific or more general categories in actuality). This is fairly representative of ImageNet: diversity of images and views of the target object, balanced with mostly center crops and strict, not-always-intuitive definitions (e.g. the "wolf" and "bird" classes could more narrowly intend a particular species).
N-way, K-shot image classification
From the MAML paper: "According to the conventional terminology, K-shot classification tasks use K input/output pairs from each class, for a total of NK data points for N-way classification."
Here are the relevant settings (argument flags) in the provided code:
- num_classes: N, as in N-way classification, is the number of different image classes we're learning in each task
- update_batch_size: K, as in K-shot learning, is the number of examples seen for each class to update the inner gradient on a task-specific model
So, 5-way, 1-shot MIN considers 1 labeled image from each of 5 classes (a total of 5 images). 5-way, 5-shot MIN considers 5 labeled images from each of 5 classes (a total of 25 images). Some example scenarios are shown below. Note how much the diversity of classes in a given N-way task may vary: e.g. different species of similar-looking dogs or the range of visuals used to represent "lipstick" may be much harder to learn.
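To make the task structure concrete, here is a rough sketch (my own illustration, not the repo's data_generator.py) of how one N-way, K-shot task could be assembled from a mapping of class names to image paths:

```python
import random

# Hypothetical sketch of assembling one N-way, K-shot task (episode) from a
# mapping of class name -> list of image paths; not the repo's data pipeline.
def sample_episode(images_by_class, n_way=5, k_shot=1, seed=None):
    rng = random.Random(seed)
    classes = rng.sample(sorted(images_by_class), n_way)   # pick N classes
    episode = []
    for label, cls in enumerate(classes):
        shots = rng.sample(images_by_class[cls], k_shot)   # K examples per class
        episode += [(path, label) for path in shots]       # N*K labeled pairs total
    rng.shuffle(episode)
    return episode
```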
Other important flags for training dynamics
- meta_batch_size: the number of tasks sampled before we update the metaparameters of the outer model / metatraining loop
- num_updates: how many times we update the inner gradient / inner model during training
- metatrain_iterations: the total number of metatraining iterations (roughly, how many batches of example tasks the model sees). The code recommends 60,000 for MIN; I eventually switched to the default 15,000 for efficiency
- Setting the effective batch size to 8 instead of 16 yields comparable performance and much faster training
Initial observations
Here I compare meta-learning runs with K=1 shot learning (1 example for each class) while varying the number of classes (num_classes), the number of inner gradient updates (num_updates), the effective batch size, and the number of filters learned. All charts are shown with smoothing 0.8.
- Lower N is easier: intuitively, the fewer classes there are to distinguish, the better the model performs. You can see that across charts, the red (2-way) accuracy is higher than the orange (3-way), which in turn is higher than the blues (5-way).
- Number of inner gradient updates beyond 5 doesn't matter much: post-update accuracy at step 10 barely changes from step 5 (see the first row, right chart, where the labeled curves for step 5 and step 10 are almost entirely overlapping).
- Lower effective batch size trains faster with slightly lower accuracy: the original code mysteriously adds 15 to the batch size. Undoing this and setting an effective batch size of 4 only slightly decreases performance and seems more consistent with K. By the traditional definitions of meta-learning, the effective batch size should be N * K, not N * (K + 15) (see the arithmetic after this list).
- Doubling the number of filters slightly increases performance: as expected with more learning capacity; compare the medium-blue run with 64 filters (suffix "fs64") to the light blue and darker blue runs, which have 32 filters
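To make the batch-size point concrete, for 5-way, 1-shot MIN the gap between the two definitions is large:

```python
N, K = 5, 1           # 5-way, 1-shot
print(N * K)          # 5  -> effective batch size by the usual N*K definition
print(N * (K + 15))   # 80 -> per-task examples when the code adds 15 to K
```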
Using three repos and a gist
- The original repo notes that the mini-ImageNet dataset is available from Ravi & Larochelle '17, which corresponds to the paper Optimization as a Model for Few-Shot Learning from ICLR 2017, and this canonical csv split of train/val/test images.
- The actual images are available in a different repo: few-shot-ssl-public in .pkl format. This repo extends few-shot classification to learning from unlabeled examples and more challenging tasks with distractor (previously unseen) classes (Meta-Learning for Semi-Supervised Few-Shot Classification, Ren et al, ICLR 2018).
- I wrote this gist to extract the images and prepare them for training in MAML. It greatly accelerated my data-extraction process because it saves each image file directly to the right location instead of moving files with a Python os.system("mv ...") call; a rough sketch of the idea appears below.
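For reference, here is a minimal sketch of that kind of extraction (this is not the gist itself, and the 'image_data' / 'class_dict' key names are my assumptions about the .pkl layout):

```python
import os
import pickle

from PIL import Image

# Rough sketch of extracting images from the few-shot-ssl-public .pkl files into
# per-class folders for MAML's data loader. This is NOT the gist itself, and the
# 'image_data' / 'class_dict' key names are assumptions about the pickle layout.
def extract_split(pkl_path, out_dir):
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)
    images = data["image_data"]                        # assumed: (num_images, 84, 84, 3) uint8 array
    for cls, indices in data["class_dict"].items():    # assumed: class name -> image indices
        cls_dir = os.path.join(out_dir, cls)
        os.makedirs(cls_dir, exist_ok=True)
        for i in indices:
            # Save each image directly to its class folder -- no os.system("mv ...") needed.
            Image.fromarray(images[i]).save(os.path.join(cls_dir, f"{i}.jpg"))

# extract_split("mini-imagenet-cache-train.pkl", "miniImagenet/train")
```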
Training data setup
For mini-ImageNet (MIN), the data is split into train (64 classes), validation (16 classes), and test (20 classes). Each class contains 600 images, each 84 x 84 pixels. data_generator.py in the main repo randomly picks classes, and randomly picks the right number of samples per class (K in K-shot learning), from the right split depending on the mode (training, evaluation, or testing). One confusing detail is that the source code increments the inner batch size K by 15 when generating training data, which may affect the correctness of image shuffling. I trained some runs with and without this modification to try to isolate its impact and necessity.
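My reading of this K + 15 detail (an interpretation, not a verified reproduction of data_generator.py) is that each task samples K + 15 images per class, with the first K per class used for the inner update and the remaining 15 per class held out for the meta-update. A sketch of that split pattern:

```python
import numpy as np

# Sketch of the support/query split I believe the pipeline is doing (my reading,
# not a verified reproduction of data_generator.py): each task has K + 15 images
# per class; the first K per class feed the inner update ("support") and the
# remaining 15 per class feed the meta-update ("query").
def split_task(images, labels, n_way, k_shot, n_query=15):
    support_x, support_y, query_x, query_y = [], [], [], []
    for c in range(n_way):
        idx = np.where(labels == c)[0]                 # indices of this class's K + 15 samples
        support_x.append(images[idx[:k_shot]])
        support_y.append(labels[idx[:k_shot]])
        query_x.append(images[idx[k_shot:k_shot + n_query]])
        query_y.append(labels[idx[k_shot:k_shot + n_query]])
    return (np.concatenate(support_x), np.concatenate(support_y),
            np.concatenate(query_x), np.concatenate(query_y))
```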
Next experiments
- Fix N, vary K: how does this affect accuracy?
- Train on omniglot: how are the dynamics the same/different in character classification versus natural images?
- Mix regression and classification, and potentially RL: can we train on Omniglot and finetune on MIN, or vice versa? Is there a general model that could learn across all three problem types?
- Metatraining hyperparameters: I haven't explored many hyperparameters of the metatraining loop
- Test & metatest accuracy over number of metatraining updates: Explore the test accuracy and metatest accuracy of these trained models—how does it correlate with metaval accuracy? How does the number of metatraining updates affect this? For best results on a totally new image dataset, pick the model with the highest metatest accuracy.
- Pre-update and post-update dynamics: How can we intuitively understand the pre-update metrics? How much does the number of metatraining update steps matter (vary meta_batch_size)? Does it make sense to compare the change in loss across updates?
- Inner gradient dynamics: Why does inner validation saturate to random chance in many cases? Why is the postloss mean higher than the preloss mean?
Gotchas
- Some training configurations hit memory limits: for example, 10-way, 10-shot classification errors out as follows:
File "/home/stacey/.pyenv/versions/mm/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1950, in __init__
"Cannot create a tensor proto whose content is larger than 2GB.")
- The default for K is 5, and it appears to control update_batch_size, the number of examples we see for every class in the inner loop. However, the original code adds 15 to it when generating MIN data, supposedly to "use 15 val examples for imagenet" according to a comment/TODO. Later in the code, the training data is generated with tf.slice() and the correct, unmodified update_batch_size, but I'm not confident this reliably preserves the intricate sorting of the training data by class...