
What Is Noise Contrastive Estimation Loss? A Tutorial With Code

A tutorial covering the Noise Contrastive Estimation Loss, a commonly encountered loss function in Self-Supervised Learning
Created on February 8 | Last edited on February 14

Introduction

Noise Contrastive Estimation appears a lot in the field of Self-Supervised Learning, most commonly in its more modern forms such as InfoNCE, InfoNCE++, and NT-Xent. In this article, we'll go over the motivation, the derivation, and an implementation in PyTorch.

Noise Contrastive Estimation Loss

We know about Self-Supervised Learning, but let's try to abstract away the framework. I prefer the way Oriol Vinyals describes it in his NeurIPS 2020 invited talk: given our data and some task specification (we need to provide target information in some way!), we model the data distribution using parts of the input data itself.
But why do we approach it this way? Essentially, we're trying to limit the number of parameters we need to learn. In Supervised Learning, we model the entire input distribution against the targets, which essentially boils down to learning a partition function that splits the distribution into various groups (we can think of these as classes).
Noise Contrastive Estimation (NCE) aims to estimate the input distribution without learning a full partition function. This approach has proven very effective in language modeling: instead of computing a full softmax over the vocabulary, we estimate the next word by training a sigmoid binary classifier to separate the true target word from a few samples drawn from some noise distribution. The logits are trained so that the probability of the correct word is pushed close to 1 and the probabilities of the noise words are pushed close to 0.
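To make the binary-classifier view concrete, here is a minimal PyTorch sketch of that objective. It's a simplification rather than the exact estimator from the NCE literature: the scoring model producing the logits is assumed to live elsewhere, and the correction term involving the noise distribution is omitted.

import torch
import torch.nn.functional as F

def binary_nce_loss(target_scores: torch.Tensor, noise_scores: torch.Tensor) -> torch.Tensor:
    # target_scores: (B,)   logits for the observed next words (treated as label 1)
    # noise_scores:  (B, K) logits for K words sampled from a noise distribution (label 0)
    # Push the true words towards probability 1 ...
    pos_loss = F.binary_cross_entropy_with_logits(
        target_scores, torch.ones_like(target_scores))
    # ... and the noise words towards probability 0.
    neg_loss = F.binary_cross_entropy_with_logits(
        noise_scores, torch.zeros_like(noise_scores))
    return pos_loss + neg_loss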
Simple probabilistic modeling! Now let's build on top of this.

InfoNCE

With a simple reformulation of NCE, instead of training a sigmoid binary classifier we apply a softmax non-linearity over one positive and several noise samples, which again casts our scores to probabilities in the range [0, 1]. This simple reformulation lets us estimate (a lower bound on) the mutual information between two related data points, and it serves as the backbone for most other losses in Self-Supervised Learning.
Assuming we have representations v_i of our original data (maybe the output of some encoder), we can define the InfoNCE loss for contrastive learning as follows:
\text{InfoNCE}_{\theta} = \frac{1}{B} \sum_{i=1}^{B} \log \frac{\exp f(v_i, v_i')}{\frac{1}{B} \sum_j \exp f(v_i, v_j')}

...where f is some function which estimates the similarity between the representations, and v_i' denotes a representation which is related to v_i.
Using the same framing as above, let's distill the idea of this loss function: we train a softmax classifier (which casts the scores to probabilities in the range [0, 1]) to pick out the target representation v_i' from among a few samples drawn from some noise distribution.
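As a rough sketch of what this looks like with an explicit noise set, a common choice is to let f be a temperature-scaled cosine similarity and to classify the positive among the candidates with a softmax cross-entropy. The function names, tensor shapes, and the explicit negatives tensor below are illustrative assumptions, not a reference implementation.

import torch
import torch.nn.functional as F

def f(a: torch.Tensor, b: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    # A common choice for f: temperature-scaled cosine similarity.
    return F.cosine_similarity(a, b, dim=-1) / temperature

def info_nce_explicit(query, positive, negatives, temperature=0.1):
    # query:     (B, D)    anchor representations v_i
    # positive:  (B, D)    related representations v_i'
    # negatives: (B, K, D) K unrelated samples per anchor
    pos_logits = f(query, positive, temperature).unsqueeze(1)   # (B, 1)
    neg_logits = f(query.unsqueeze(1), negatives, temperature)  # (B, K)
    logits = torch.cat([pos_logits, neg_logits], dim=1)         # (B, 1 + K)
    # For every anchor, the positive candidate sits at index 0.
    labels = torch.zeros(len(query), dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, labels)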
A more compact (and more explicitly probabilistic) definition is the one described in Representation Learning with Contrastive Predictive Coding:
\mathcal{L}_N = - \mathbb{E}_X \left[ \log \frac{f_k(x_{t+k}, c_t)}{\sum_{x_j \in X} f_k(x_j, c_t)} \right]

where X = {x_1, ..., x_N} is a set of N random samples "containing one positive sample from p(x_{t+k} | c_t) and N-1 negative samples from the 'proposal' distribution p(x_{t+k})".
Optimizing the general InfoNCE objective trains the logits to approximate what a full softmax would compute, and the resulting representations maximise a lower bound on the mutual information between correlated views.
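Concretely, the CPC paper shows that minimising \mathcal{L}_N maximises a lower bound on this mutual information:

I(x_{t+k}; c_t) \geq \log N - \mathcal{L}_N

...so the bound becomes tighter as the number of samples N grows.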
In most methods we don't actually specify negative samples explicitly; instead, all other samples in the batch are treated as negatives (as in NT-Xent-style batch losses).

Code

Let's look at a very basic abstract implementation of the InfoNCE loss function in PyTorch. [Source]
import torch
import torch.nn.functional as F

def infoNCE(query, positive_key, temperature=0.1, reduction='mean') -> torch.Tensor:
    # Normalize so the dot products below are cosine similarities.
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)

    # Cosine similarity between all (query, key) combinations.
    # Negative keys are implicitly the off-diagonal positive keys.
    logits = query @ positive_key.transpose(-2, -1)

    # The positive key for each query is the entry on the diagonal.
    labels = torch.arange(len(query), device=query.device)

    return F.cross_entropy(logits / temperature, labels, reduction=reduction)
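
A hypothetical usage sketch, where encoder, augment, and images are placeholders for your own model, augmentation pipeline, and data batch:

# Two correlated (augmented) views of the same batch -- `augment`, `encoder`,
# and `images` are placeholders, not part of any specific library.
view_a, view_b = augment(images), augment(images)
query = encoder(view_a)         # (B, D) representations
positive_key = encoder(view_b)  # (B, D) representations

loss = infoNCE(query, positive_key, temperature=0.1)
loss.backward()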


👋 Summary

That wraps up our short tutorial on the Noise Contrastive Estimation loss. If you'd like more reports covering the math and "from-scratch" code implementations of Self-Supervised Learning methods, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other Self-Supervised Learning topics, such as the DINO framework.
