
Navigating Over-parametrized Feature Space with Meta-Gradients


Introduction

Meta-learning is well known for the promise of "learning how to learn" and has been studied extensively in the recent machine learning literature. At a high level, a meta-learning algorithm differs from single-purpose learning methods in that it takes in a distribution of related tasks and tries to find a generally useful meta-learner that can adapt to all of them; at test time, it can then quickly learn an unseen variant from the training task distribution, requiring far fewer examples than learning the new task from scratch. To give a concrete example, one popular instantiation of meta-learning is the few-shot classification problem, where each "task" is to predict the correct image label from only a few (1-5) samples, and the tasks differ in that each uses different data samples to predict the same image classes. At test time, the meta-learned model must classify novel inputs that were not in the training dataset.
In this report, we aim to provide more intuition for how and why meta-learning works in these multi-task setups, and to explore a slightly different aspect of what meta-learning methods can do. We focus on a bare-bones problem setting: learning 1D functions from Fourier features in an over-parameterized space. We will see that meta-gradients (a specific meta-learning method that integrates naturally with gradient-descent parameter updates) can easily sort through the high-dimensional feature space and recover the "truly useful" features underneath, and can therefore excel on any task whose labels are generated from these true features, which direct single-task supervised learning cannot achieve here because of over-parameterization.

Problem Setup

Let's consider a minimalist formulation: a linear regression problem $y = c^T (w * \Phi(x))$, where:
  • $x \in \mathbb{R}$, $\Phi(x) = [\phi_1(x), ..., \phi_d(x)] \in \mathbb{R}^d$: $\Phi$ is the Fourier featurizing function, with each $\phi_k(x) = e^{j 2\pi k x}$, $k \in \{1, ..., d\}$
  • $w \in \mathbb{R}^d$ are called feature weights: $w$ has the same shape as $\Phi(x)$ and thus puts a scalar weight on each feature $\phi_k(x)$
  • $c \in \mathbb{R}^d$ are called feature coefficients: they parametrize a linear combination of the weighted features $(w * \Phi(x))$, effectively down-projecting the $d$-dimensional weighted features to a label $y \in \mathbb{R}$
Then, we sample from a grid of $x$ values and featurize them to get a small batch of input data $\Phi(X)$ and labels $Y$:
  • $X \in \mathbb{R}^{N \times 1}$, $\Phi(X) \in \mathbb{R}^{N \times d}$: a "small" batch because we sample fewer data points than the size of the feature space, i.e. $N < d$
  • $Y = c^T (w * \Phi(X))$, $Y \in \mathbb{R}^{N \times 1}$: the model applied row-wise, giving a batch of 1-dimensional labels
  • For all experiments below, we fix $d = 227$ and $N = 64$
  • Because all of the $x$ values are sampled in the same way, we can define one task by a specific pair $(w, c)$, which lets us generate and sample data for that task in a controllable fashion, as sketched in code below.
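To make the setup concrete, here is a minimal, runnable numpy sketch of this data-generation pipeline. It is an illustration rather than the report's actual code: to keep labels and the downstream regression real-valued, it uses the real part $\cos(2\pi k x)$ of the complex features defined above, and the helper names (fourier_features, Phi) are our own.

import numpy as np

def fourier_features(x, d):
    # Phi(x)[:, k-1] = Re(e^{j 2*pi*k*x}) = cos(2*pi*k*x) for k = 1..d
    k = np.arange(1, d + 1)
    return np.cos(2 * np.pi * np.outer(x, k))  # shape (N, d)

rng = np.random.default_rng(0)
d, N = 227, 64                        # feature dimension > batch size
w = np.zeros(d); w[:3] = 1.0          # "true" feature weights: first three features active
x = rng.uniform(-1.0, 1.0, size=N)    # small batch of scalar inputs
Phi = fourier_features(x, d)          # featurized batch, (N, d)
c = rng.standard_normal(d)            # one task = one sampled coefficient vector c
y = (Phi * w) @ c                     # y = c^T (w * Phi(x)) applied row-wise, shape (N,)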
Below, we first visualize one such batch of data:

Visualize individual features

Given one batch of $X$, we generate a concatenation of $d$ Fourier features: $\Phi(X) = [\phi_1(X), ..., \phi_d(X)]^T \in \mathbb{R}^{N \times d}$. In each scatter plot below, we pick one $k \in \{1, ..., d\}$ and visualize what $\phi_k(X)$ looks like over a batch of $X$ sampled within the interval [-1, 1]. All experiment runs use the same featurization, and it's evident that the features oscillate at higher frequency as the feature index $k$ increases.
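These plots are easy to reproduce; a small matplotlib sketch, reusing x and fourier_features from the snippet above, with an arbitrary choice of feature indices:

import matplotlib.pyplot as plt

ks = [1, 5, 20]                       # a few feature indices to compare
fig, axes = plt.subplots(1, len(ks), figsize=(9, 3))
for ax, k in zip(axes, ks):
    ax.scatter(x, np.cos(2 * np.pi * k * x), s=8)
    ax.set_title(f"phi_{k}(X)")
    ax.set_xlabel("x")
plt.tight_layout()
plt.show()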


[W&B panel: per-feature scatter plots over the run set (3 runs)]


Visualize the Y labels


After fixing a set of feature weights $\mathbf{w}$ (here we use indices 0, 1, 2 of the Fourier features and weight all other features with 0, i.e. the vector $\mathbf{w} = [1, 1, 1, 0, ..., 0] \in \mathbb{R}^d$), the remaining degree of freedom in defining an exact task is the set of feature coefficients $\mathbf{c}$. To provide coverage of a diversity of tasks, we randomly sample each coefficient from a Gaussian distribution, $\mathbf{c} \sim \mathcal{N}(\mu = 0, \sigma = 1)$, and generate the $y$ label for each task according to $y = c^T (w * \Phi(x))$. Therefore, despite using the same feature weights $\mathbf{w}$, the "True Label Y" visualizations from the experiment runs show different value distributions over the same grid of inputs $X$. In addition to the training labels, we can also compute a quick closed-form solution on the small data batch via:
from sklearn import linear_model

# X here is the featurized training batch Phi(X), shape (N, d); y are the labels.
reg = linear_model.LinearRegression().fit(X, y)
# big_X is a densely sampled, featurized grid from the same input distribution.
close_y = big_X @ reg.coef_ + reg.intercept_
Predictions from this closed-form solution on the training data are shown in the right column below: it perfectly fits all the labels in the training batch (which contains only N = 64 data points). However, after densely sampling more data from the exact same training distribution, we see this solution fail badly on unseen test data (second row below).
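Numerically, this failure mode is easy to reproduce with a sketch like the following, reusing Phi, w, c, and fourier_features from the snippets above and assuming reg was fit on the featurized batch; the exact MSE values depend on the random seed:

x_test = np.linspace(-1.0, 1.0, 1024)                      # dense grid, same input distribution
Phi_test = fourier_features(x_test, d)
y_test = (Phi_test * w) @ c                                # ground-truth labels on the dense grid
mse_train = np.mean((reg.predict(Phi) - y) ** 2)           # ~0: with N = 64 < d the fit interpolates
mse_test = np.mean((reg.predict(Phi_test) - y_test) ** 2)  # large: fails off the training batch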

[W&B panels: true labels and closed-form predictions for run sets "5 Inner Tasks" (5 runs) and "15 Inner Tasks" (2 runs)]



Meta-learn Feature Weights $\mathbf{w}$ by varying task coefficients $\mathbf{c}$

As we've seen above, this over-parameterized setup poses a key challenge: learning the "truly useful" feature weights $\mathbf{w}$ is a prerequisite for learning any set of coefficients $\mathbf{c}$ for a task. While $\mathbf{w}$ is difficult to learn directly, we can take advantage of being able to generate many different tasks and use meta-learning to discover the commonly useful $\mathbf{w}$ underneath.
Many machine learning techniques fall under the nebulous heading of meta-learning; here we try two such methods, both of which optimize the initial weights of a network (in our setup, $\mathbf{w}$) to rapidly converge to low loss across a distribution of varying tasks and learn a good $\mathbf{c}$. The first approach is Model-Agnostic Meta-Learning (MAML), which uses second-order gradients to update an outer loop of parameters; the second is a first-order method named Reptile, which only uses weighted parameter averages to update the same "outer loop" of weights. At a high level, both methods work by sampling a "mini-batch" of tasks $\{T_i\}$, all of which share the true underlying $\mathbf{w}^*$, and alternate between (a) finding the task-specific coefficients $\mathbf{c}_i$ for a range of inner tasks using a fixed $\mathbf{w}_j$, and (b) updating the outer $\mathbf{w}_j$ based on the inner tasks' validation performance.

The key difference between MAML and Reptile lies in how they update the outer feature weights $\mathbf{w}_j$: MAML uses second-order gradients to calculate $\nabla \mathbf{w}_j$. After iterating through a few inner tasks and updating each $\mathbf{c}_i$ by regular gradient descent starting from some initialization $\mathbf{c}_0$ (while keeping $\mathbf{w}_j$ fixed), it collects the validation losses across these tasks, calculates meta-gradients with respect to the fixed feature weights $\mathbf{w}_j$, and updates them by meta-gradient descent to get $\mathbf{w}_{j+1}$. The pseudocode for our specific setting is as follows:
Initialize $\mathbf{w}_0$
for outer iteration $j = 1, 2, 3, ...$ do
    for inner task $i = 1, 2, ..., n$ do
        Randomly sample a task $T_i$
        Perform $k$ steps of SGD on task $T_i$ starting from parameters $\mathbf{c}_0$, resulting in parameters $\mathbf{c}_k$
        Sample more data from the same task and calculate the validation loss $l_i$
    end for
    Calculate the meta-gradient $\nabla \mathbf{w}_j$ of the averaged validation loss $\frac{1}{n} \sum_i l_i$
    Update: $\mathbf{w}_{j+1} = \mathbf{w}_j - \eta \nabla \mathbf{w}_j$
end for
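To make the meta-gradient step concrete, here is a minimal sketch of this loop using JAX's automatic differentiation. This is our own illustrative code, not the report's implementation; the helper names and hyper-parameter values (k, alpha) are assumptions. Differentiating meta_loss with respect to $\mathbf{w}$ automatically carries the second-order terms through the unrolled inner SGD steps:

import jax
import jax.numpy as jnp

def task_loss(w, c, Phi, y):
    # Mean squared error of the linear model y_hat = (Phi * w) @ c
    return jnp.mean(((Phi * w) @ c - y) ** 2)

def inner_adapt(w, Phi, y, k=5, alpha=0.01):
    # Inner loop: k SGD steps on c only, with w held fixed
    c = jnp.zeros(Phi.shape[1])
    for _ in range(k):
        c = c - alpha * jax.grad(task_loss, argnums=1)(w, c, Phi, y)
    return c

def meta_loss(w, Phi_tr, y_tr, Phi_val, y_val):
    # Validation loss of one task after inner adaptation of c
    c_k = inner_adapt(w, Phi_tr, y_tr)
    return task_loss(w, c_k, Phi_val, y_val)

# Meta-gradient w.r.t. w; averaging this over a mini-batch of tasks and applying
# w_{j+1} = w_j - eta * meta_grad completes one outer iteration.
meta_grad_fn = jax.grad(meta_loss)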

Reptile, by contrast, does not require this inner-outer loop distinction: for each newly sampled task, it performs a few regular gradient updates on both $\mathbf{w}_j$ and $\mathbf{c}_i$; when switching to a new task, it discards the learned $\mathbf{c}_i$ and updates $\mathbf{w}_j$ by interpolating between its previously saved value and the task-adapted weights. The pseudocode is:
Initialize $\mathbf{w}_0$
for iteration $j = 1, 2, 3, ...$ do
    Randomly sample a task $T_j$
    Perform $k$ steps of SGD on task $T_j$ starting from parameters $\mathbf{w}_j, \mathbf{c}_0$, resulting in parameters $\mathbf{w}_{j+k}, \mathbf{c}_k$
    Update: $\mathbf{w}_{j+1} = \mathbf{w}_j + \epsilon (\mathbf{w}_{j+k} - \mathbf{w}_j)$
end for
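For comparison with the MAML sketch above, here is a minimal numpy sketch of one Reptile update for our linear model. Again, this is our own illustration rather than the report's code: the analytic gradients follow from $\hat{y} = (\Phi * w) \, c$, and the step sizes (k, alpha, eps) are assumptions.

def sgd_adapt(w0, Phi, y, k=10, alpha=0.01):
    # Jointly adapt (w, c) on one task with k SGD steps, starting from c_0 = 0
    w, c = w0.copy(), np.zeros(Phi.shape[1])
    for _ in range(k):
        r = (Phi * w) @ c - y                        # residuals, shape (N,)
        g = (2.0 / len(y)) * (Phi.T @ r)             # shared gradient factor, shape (d,)
        w, c = w - alpha * g * c, c - alpha * g * w  # grad_w = g * c, grad_c = g * w
    return w

def reptile_update(w_j, Phi, y, eps=0.1):
    # Interpolate the outer weights toward the task-adapted weights
    w_adapted = sgd_adapt(w_j, Phi, y)
    return w_j + eps * (w_adapted - w_j)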
Sources: C. Finn, P. Abbeel, and S. Levine, "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks," Proceedings of the 34th International Conference on Machine Learning (ICML), PMLR 70, Sydney, Australia, 2017. A. Nichol, J. Achiam, and J. Schulman, "On First-Order Meta-Learning Algorithms," arXiv:1803.02999, 2018.


Results

For both methods (MAML and Reptile), we showcase and compare one set of hyper-parameters and two different initializations of $\mathbf{w}$: 1) initializing all feature weights to 1: $\mathbf{w}_0 = [1, ..., 1]$; 2) uniformly sampling each index's weight from the interval $[-1, 1]$: $\mathbf{w}_0 = [w_1, ..., w_d]^T$, $w_k \sim \text{Unif}(-1, 1)$.
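In numpy terms, the two initializations amount to the following (a sketch, reusing d and rng from the earlier snippets):

w0_ones = np.ones(d)                      # 1) constant initialization
w0_unif = rng.uniform(-1.0, 1.0, size=d)  # 2) uniform random initialization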

1. MAML Results


[W&B panels: run sets MAML (2 runs) and Reptile (4 runs)]


2. Reptile Results


[W&B panels: run sets MAML (2 runs) and Reptile (2 runs)]