Self-Supervised Image Recognition with IJEPA
Created on June 29|Last edited on July 13
Supervised learning models, trained on labeled datasets, have achieved impressive results across various domains in the past decade, from image classification to natural language processing. However, supervised learning's reliance on large amounts of labeled data and vulnerability to noisy or adversarial inputs present significant challenges.
Image-based Joint-Embedding Predictive Architecture (IJEPA) has emerged as a promising alternative that addresses some of the limitations of other self-supervised learning methods. Unlike self-supervised methods that rely heavily on hand-crafted data augmentations, IJEPA learns by predicting missing information in an abstract representation space, with the aim of improving model robustness and generalization even when labeled data is limited.
We'll show how to train IJEPA in the self-supervised regime, without labels, and how to run linear probing on the pretrained model to examine the quality of its representations.

The high cost of labeling
The promise of self-supervised learning
Understanding IJEPA: Theory
How IJEPA differs from other self-supervised methods
Performance on diverse tasks
Show me the code
Model logging with Weights & Biases
Linear-probing after pre-training
Conclusion

The high cost of labeling
Labeling involves human labor, frequently requiring domain expertise to ensure the annotations are accurate and reliable. For example, in medical imaging, experts such as radiologists or pathologists must review each image to either provide accurate labels or validate ones made by non-experts. This not only incurs significant costs but also introduces delays in the data preparation process. Additionally, maintaining consistent and high-quality labels across large datasets is a challenging task that further increases the complexity and expense.
The promise of self-supervised learning
Self-supervised learning (SSL) methods like IJEPA offer a compelling alternative to traditional supervised learning by reducing the dependency on labeled data. These methods leverage the inherent structure in the data to create supervisory signals, enabling models to learn useful representations without explicit labels. The ability to train models using large volumes of unlabeled data significantly mitigates the costs and time associated with data labeling.
Understanding IJEPA: Theory
IJEPA is designed to learn highly semantic image representations through a self-supervised learning approach. It relies on predicting the representations of target blocks within an image from a single context block, using a strategic masking method to guide the learning process.
IJEPA falls under the category of Joint-Embedding Predictive Architectures, which aim to predict the embeddings of a signal y from a compatible signal x. Unlike traditional generative methods that operate in the input space, IJEPA makes predictions in the latent space, focusing on high-level features rather than pixel-level details.
The masking strategy is crucial to IJEPA's success. It involves context blocks and target blocks. The context block is a sufficiently informative and spatially distributed portion of the image that provides the necessary context for prediction. Given context, the model is tasked with producing a latent representation similar to the latent representation produced by feeding the target blocks through the target encoder.
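The block sampling described above can be sketched without any deep learning framework. This is a minimal illustration, not the repo's mask collator: the function names are hypothetical, and the scale and aspect-ratio ranges (targets covering roughly 15 to 20% of the image, a context block covering 85 to 100%) are assumed values chosen for the example.

```python
import random


def sample_block(grid, scale, aspect, rng):
    """Return the set of patch indices covered by one rectangular block.

    grid   -- patches per side (e.g. 14 for a 224px image with 16px patches)
    scale  -- fraction of the patch grid the block should cover
    aspect -- height / width ratio of the block
    """
    area = scale * grid * grid
    h = max(1, min(grid, round((area * aspect) ** 0.5)))
    w = max(1, min(grid, round((area / aspect) ** 0.5)))
    top = rng.randrange(grid - h + 1)
    left = rng.randrange(grid - w + 1)
    return {r * grid + c for r in range(top, top + h) for c in range(left, left + w)}


def sample_masks(grid=14, num_targets=4, seed=0):
    """Sample target blocks and one context block that excludes them."""
    rng = random.Random(seed)
    # Assumed ranges: small target blocks with a varying aspect ratio.
    targets = [
        sample_block(grid, rng.uniform(0.15, 0.2), rng.uniform(0.75, 1.5), rng)
        for _ in range(num_targets)
    ]
    # Assumed range: one large, square context block.
    context = sample_block(grid, rng.uniform(0.85, 1.0), 1.0, rng)
    # Remove target patches from the context so the task is not trivial.
    for block in targets:
        context -= block
    return context, targets


context, targets = sample_masks()
print(len(context), [len(t) for t in targets])
```

Removing the target patches from the context is the important step: without it, the predictor could simply copy information it has already seen.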

Energy-Based Models (EBMs) form the underlying principle of IJEPA. In this context, energy represents the compatibility between the predicted embeddings and the actual target embeddings. The goal is to minimize the energy (or difference) between these embeddings, ensuring that similar inputs have low energy (high compatibility) and dissimilar inputs have high energy (low compatibility).
The training objective is to ensure that the predictor produces embeddings for the target blocks that closely match the embeddings generated by the target encoder. The target encoder is essentially a duplicate of the context encoder, but its weights are updated using an exponential moving average (EMA) to stabilize training and prevent representation collapse.
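The EMA update mentioned above can be illustrated on plain lists of numbers. This is a sketch under assumptions: the helper names are hypothetical, parameters are modeled as flat lists of floats rather than tensors, and the momentum values are only example settings.

```python
def ema_update(target_params, online_params, momentum):
    """In-place EMA step: target <- m * target + (1 - m) * online."""
    for i, (t, o) in enumerate(zip(target_params, online_params)):
        target_params[i] = momentum * t + (1.0 - momentum) * o


def momentum_schedule(start, end, num_steps):
    """Yield a momentum that increases linearly from start to end."""
    for step in range(num_steps):
        yield start + (end - start) * step / max(1, num_steps - 1)


# Toy parameters: the target encoder slowly tracks the context encoder.
target = [0.0, 1.0]
online = [1.0, 1.0]
for m in momentum_schedule(0.996, 1.0, 3):
    ema_update(target, online, m)
print(target)
```

Because the momentum stays close to 1, the target encoder changes far more slowly than the context encoder, which is what keeps the prediction targets stable.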
There are three models of interest in IJEPA: the context encoder, the target encoder, and the predictor. The context encoder takes the context patches and generates a set of context embeddings. The target encoder processes the entire image to generate target embeddings, with its weights updated via EMA to maintain consistency and stability. The predictor takes the context embeddings together with a set of positional mask tokens (sometimes described as "latent embeddings"), which tell it where the target blocks are located, and predicts the target block embeddings. These tokens do not come from the target encoder; the target encoder only supplies the embeddings that the predictions are compared against.

During training, the input image is first converted into non-overlapping patches. Randomly sampled target and context blocks are then extracted from these patches. The context block is fed through the context encoder to produce context embeddings. The predictor network then takes these context embeddings, conditioned on positional tokens, and predicts the target block representations. In the original formulation, the average L2 distance between the predicted and actual target embeddings serves as the loss function; the training script later in this tutorial uses a smooth L1 loss (F.smooth_l1_loss) instead.
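Written out, the L2 form of this objective can be sketched as follows. The notation is an assumption introduced here for illustration: M is the number of target blocks, B_i is the set of patch indices in target block i, \hat{s}_{y_j} is the predictor's output for patch j, and s_{y_j} is the corresponding target-encoder embedding.

```latex
\mathcal{L} \;=\; \frac{1}{M} \sum_{i=1}^{M} \sum_{j \in B_i}
  \left\lVert \hat{s}_{y_j} - s_{y_j} \right\rVert_2^2
```

Gradients of this loss flow only through the predictor and the context encoder; the target embeddings s_{y_j} are treated as constants.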
The predictor and context encoder parameters are updated using standard gradient-based optimization, while the target encoder parameters are updated via EMA to avoid representation collapse. This approach allows IJEPA to effectively learn robust, high-level representations without relying on hand-crafted augmentations, making it a powerful tool for self-supervised learning capable of generalizing across various tasks and data distributions.
Something to note here is that the target encoder receives the entire image, and the masking is applied to its output embeddings rather than to its input, which is somewhat counterintuitive. This subtlety is important, as it helps the model learn highly semantic representations of the images.
How IJEPA differs from other self-supervised methods
IJEPA stands out from other self-supervised methods primarily because it eliminates the need for handcrafted data augmentations, such as random cropping, scaling, and color jittering. These traditional augmentations are manually designed and may introduce biases that do not generalize well across different tasks or data distributions. In contrast, IJEPA's approach of predicting target block representations from a context block within the same image allows the model to learn more general features that are less tied to any particular augmentation.
Efficiency compared to other methods
One of the key benefits of IJEPA is its computational efficiency. IJEPA significantly reduces the computational resources needed for pretraining compared to methods that rely on multiple augmented views. For example, pretraining a Vision Transformer (ViT)-Huge/14 model with IJEPA on ImageNet requires fewer than 1,200 GPU hours. This is over 2.5 times faster than pretraining a ViT-Small/16 with iBOT and more than 10 times more efficient than pretraining a ViT-Huge/14 with Masked Autoencoders (MAE).
This efficiency is achieved because IJEPA only needs to process a single view of each image, rather than generating and processing multiple augmented views. By predicting in representation space rather than input space, IJEPA reduces the overall computational burden, making it a more scalable and practical solution for large-scale self-supervised learning.
Performance on diverse tasks
IJEPA demonstrates competitive or superior performance on various tasks beyond classification, such as object counting and depth prediction. These tasks benefit from IJEPA's ability to capture both high-level semantic features and local image details, without being influenced by augmentation-induced distortions.
In summary, while handcrafted augmentations can achieve high performance, IJEPA's approach provides a more flexible, efficient, and generalizable solution. It demonstrates superior scalability and efficiency, making it a promising approach for self-supervised learning in diverse applications.
Show me the code
We will be building on Meta's official repo. If you are following along with this tutorial, however, I would recommend looking at my repo, which includes all of the code used here.
The underlying architecture of our models revolves around Vision Transformers (ViTs). For self-supervised training, I will be using the 'vit_tiny' variant. The 'vit_tiny' model features an embedding dimension of 192 and a depth of 12 layers. The Vision Transformer consists of a patch embedding layer that divides the input image into non-overlapping patches and projects them into a high-dimensional space. This is followed by multiple transformer blocks, each incorporating multi-head self-attention mechanisms and feed-forward neural networks. The output embeddings are then normalized using a layer normalization.
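To make the patch embedding step concrete, here is a small sketch of how the number of tokens follows from the image and patch sizes. The helper name is hypothetical, and the image and patch sizes used below are illustrative assumptions, not the settings of this tutorial's training run.

```python
def num_patches(image_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


# A 224x224 image with 14x14 patches gives a 16x16 grid of tokens.
print(num_patches(224, 14))  # 256
# The same image with 16x16 patches gives a 14x14 grid.
print(num_patches(224, 16))  # 196
```

For 'vit_tiny', each of these tokens is then projected to a 192-dimensional embedding before entering the transformer blocks.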
For this tutorial, I found that the 'vit_tiny' model pretrained on Cifar-10 was not strong enough to get good downstream performance with a simple linear layer added to the frozen backbone. Later on, we will instead use Meta's pretrained checkpoint for downstream classification.
For the IJEPA model training setup, an additional component called the predictor is introduced along with the two encoders. The predictor uses the embeddings generated by the encoder, along with positional embeddings (latent variables telling the predictor where it needs to predict the unknown target patches), to predict the embeddings of the masked target blocks. This helps the model learn to infer missing parts of the image based on the context provided by the visible patches.
In contrast, the regular ViT model trained purely under supervised learning does not use a predictor. Instead, it relies solely on the encoder to process the entire input image and directly generate classification outputs. The outputs from the transformer are average pooled, and a classifier layer is applied to produce the final predictions. This distinction highlights how IJEPA leverages additional predictive tasks to enhance its feature learning capabilities, whereas the regular ViT focuses on direct supervised learning from labeled data.
Here's a core chunk of our IJEPA training script, which trains the models to learn representations of the images without labels:
for epoch in range(start_epoch, num_epochs):
    logger.info('Epoch %d' % (epoch + 1))

    # -- update distributed-data-loader epoch
    train_sampler.set_epoch(epoch)

    loss_meter = AverageMeter()
    maskA_meter = AverageMeter()
    maskB_meter = AverageMeter()
    time_meter = AverageMeter()

    encoder.train()
    predictor.train()
    optimizer.zero_grad()

    for itr, (udata, masks_enc, masks_pred) in enumerate(train_loader):

        def load_imgs():
            # -- unsupervised imgs
            imgs = udata[0].to(device, non_blocking=True)
            masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
            masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
            return (imgs, masks_1, masks_2)
        imgs, masks_enc, masks_pred = load_imgs()
        maskA_meter.update(len(masks_enc[0][0]))
        maskB_meter.update(len(masks_pred[0][0]))

        def train_step():
            _new_lr = scheduler.step()
            _new_wd = wd_scheduler.step()
            # --

            def forward_target():
                with torch.no_grad():
                    h = target_encoder(imgs)
                    h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
                    B = len(h)
                    # -- create targets (masked regions of h)
                    h = apply_masks(h, masks_pred)
                    h = repeat_interleave_batch(h, B, repeat=len(masks_enc))
                    return h

            def forward_context():
                z = encoder(imgs, masks_enc)
                z = predictor(z, masks_enc, masks_pred)
                return z

            def loss_fn(z, h):
                loss = F.smooth_l1_loss(z, h)
                loss = AllReduce.apply(loss)
                return loss

            # Step 1. Forward
            with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):
                h = forward_target()
                z = forward_context()
                loss = loss_fn(z, h)

            # Step 2. Backward & step
            if use_bfloat16:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            grad_stats = grad_logger(encoder.named_parameters())

            # Accumulate gradients
            if (itr + 1) % accum_steps == 0:
                if use_bfloat16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

            # Step 3. momentum update of target encoder
            with torch.no_grad():
                m = next(momentum_scheduler)
                for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
                    param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

            return (float(loss), _new_lr, _new_wd, grad_stats)

        (loss, _new_lr, _new_wd, grad_stats), etime = gpu_timer(train_step)
        loss_meter.update(loss, imgs.size(0))
        time_meter.update(etime, imgs.size(0))

        if (itr + 1) % log_freq == 0:
            logger.info(f'Epoch [{epoch + 1}/{num_epochs}], Step [{itr + 1}/{ipe}], '
                        f'Loss: {loss_meter.avg:.4f}, Time: {time_meter.avg:.2f}s')

    # Validation step
    val_loss = validate(val_loader, encoder, predictor, target_encoder)
    logger.info(f'Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.4f}')

    # Log to wandb
    run.log({
        'Epoch': epoch + 1,
        'Train Loss': loss_meter.avg,
        'Validation Loss': val_loss,
    })

    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        pth = save_checkpoint(epoch)
        best_model_path = pth

run.log_model(path=best_model_path, name="ijepa_best_model", aliases=["best"])
run.finish()
The core components of the IJEPA training script revolve around three main models: the encoder, the predictor, and the target encoder. Additionally, the script uses Weights & Biases to log important training metrics.
The encoder is responsible for processing the context blocks from the input images and generating their embeddings. During each iteration, context patches are extracted from the images and passed through the encoder to produce context embeddings.
The predictor takes the context embeddings generated by the encoder along with positional embeddings of the masks and predicts the embeddings of the target blocks. This helps the model learn to predict missing parts of the image based on the context.
The target encoder processes the entire image to generate target embeddings, and then these embeddings are subsequently masked according to the given target region. These masked embeddings essentially serve as a reference for what the correct target embeddings should look like. The weights of the target encoder are updated using an exponential moving average (EMA) of the encoder’s weights to stabilize training and prevent representation collapse.
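The "encode the full image, then mask the outputs" step can be sketched on plain Python lists. This is a simplified stand-in for the repo's apply_masks helper, which operates on batched tensors; the function names and the toy one-dimensional embeddings below are assumptions made for illustration.

```python
def apply_mask(tokens, keep_indices):
    """Keep only the patch embeddings whose indices are in keep_indices."""
    return [tokens[i] for i in keep_indices]


def make_targets(full_image_tokens, target_blocks):
    """Build one list of target embeddings per target block."""
    return [apply_mask(full_image_tokens, block) for block in target_blocks]


# Toy example: four patches, each with a 1-dimensional embedding.
tokens = [[0.0], [1.0], [2.0], [3.0]]
target_blocks = [[1, 2], [3]]
print(make_targets(tokens, target_blocks))  # [[[1.0], [2.0]], [[3.0]]]
```

Because the mask is applied after encoding, each target embedding has already attended to the whole image, so the predictor is asked to match context-aware features rather than isolated patch features.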
The main IJEPA logic is shown in this section:
def forward_target():
    with torch.no_grad():
        h = target_encoder(imgs)
        h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
        B = len(h)
        # -- create targets (masked regions of h)
        h = apply_masks(h, masks_pred)
        h = repeat_interleave_batch(h, B, repeat=len(masks_enc))
        return h

def forward_context():
    z = encoder(imgs, masks_enc)
    z = predictor(z, masks_enc, masks_pred)
    return z

def loss_fn(z, h):
    loss = F.smooth_l1_loss(z, h)
    loss = AllReduce.apply(loss)
    return loss

# Step 1. Forward
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):
    h = forward_target()
    z = forward_context()
    loss = loss_fn(z, h)

# Step 2. Backward & step
if use_bfloat16:
    scaler.scale(loss).backward()
else:
    loss.backward()

Here we define functions that generate output embeddings from our context blocks and target blocks, and then we compute a loss between these two sets of embeddings. The goal is for the predictor's embeddings to match the target encoder's embeddings as closely as possible, which pushes the encoder toward meaningful representations of the image.
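For reference, the smooth L1 loss used in loss_fn can be written out in a few lines. This is a framework-free sketch of the standard definition with mean reduction and a threshold beta of 1.0; it works on flat lists of floats rather than tensors, and the function name is hypothetical.

```python
def smooth_l1(pred, target, beta=1.0):
    """Mean smooth L1 loss: quadratic for small errors, linear for large ones."""
    total = 0.0
    for p, t in zip(pred, target):
        diff = abs(p - t)
        if diff < beta:
            total += 0.5 * diff * diff / beta
        else:
            total += diff - 0.5 * beta
    return total / len(pred)


# One small error (quadratic branch) and one large error (linear branch).
print(smooth_l1([0.5, 3.0], [0.0, 0.0]))  # 1.3125
```

Compared with a plain squared error, the linear branch limits the influence of occasional large differences between predicted and target embeddings.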
Model logging with Weights & Biases
We also use the log_model function provided by Weights & Biases to save our best model at the end of training. If we navigate to the 'artifacts' pane in our Weights & Biases project, we will see our model as shown below:

At the top right, you will see a button that says "Link to registry" which will allow us to add the model to our model registry in W&B. After adding the model to your model registry, you can navigate to the registry, and find instructions on how to download the model later! Here is a code snippet showing how to load the model!
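A minimal sketch of downloading the logged model through the W&B public API might look like the following. The entity and project names are placeholders, the helper names are hypothetical, and the path format assumes the artifact is addressed as it was logged by the run; a model linked to a registry is addressed by the path shown on its registry page instead.

```python
def artifact_path(entity, project, name, alias="best"):
    """Build an artifact path of the form entity/project/name:alias."""
    return f"{entity}/{project}/{name}:{alias}"


def download_model(path):
    """Download a model artifact and return the local directory it was saved to."""
    import wandb  # imported lazily so the path helper works without wandb installed

    api = wandb.Api()
    artifact = api.artifact(path, type="model")
    return artifact.download()


# "my-team" and "ijepa" are placeholder names, not the ones used in this report.
path = artifact_path("my-team", "ijepa", "ijepa_best_model")
print(path)
# local_dir = download_model(path)  # requires a W&B login and network access
```

The downloaded directory contains the checkpoint file that was passed to log_model, which can then be loaded with the usual checkpoint-loading code.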

Here are the logs for my training run. Initially, I attempted to use a batch size of 32, but the model failed to improve. After increasing the batch size to 2048, the results were much better!
Run: tough-gorge-27
Linear-probing after pre-training
Now, you may be wondering how inference is executed using IJEPA models, since the pre-training process involves masking with 3 separate models, without labels or class tokens. Here we will use the “target encoder” to generate the embeddings, and then average pool each output embedding before feeding it through a final linear layer to generate output predictions from the model.
Earlier, we pretrained a 'vit_tiny'. In practice, however, we would typically want a much larger model like vit_huge trained on a larger dataset like ImageNet 22k, so I decided not to do linear probing with our original model and instead used Meta's checkpoint for the vit_huge model trained on ImageNet 22k. To use this model for linear probing, you will need to add a bit of code to the existing model definition. First, you will need to average the output embeddings of the last transformer layer over the patch dimension. Second, you will need to add a linear classifier layer that outputs the class logits. Here's a short code snippet showing the modification:
def forward(self, x, masks=None, return_avg_embed=False):
    # ....... rest of forward method omitted
    if return_avg_embed:
        avg_embed = x.mean(dim=1)  # avg pool
        logits = self.classifier(avg_embed)  # linear layer
        return logits
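To show what these two added steps compute, here is a framework-free sketch of average pooling followed by a linear layer for a single image. The function names, the toy embeddings, and the weights are all made up for illustration; the real model does the same arithmetic on batched tensors.

```python
def mean_pool(tokens):
    """Average a list of patch embeddings into one image-level embedding."""
    dim = len(tokens[0])
    return [sum(tok[d] for tok in tokens) / len(tokens) for d in range(dim)]


def linear(x, weight, bias):
    """Apply a linear layer: one output logit per row of weight."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


# Two patch embeddings of dimension 2, classified into 3 classes.
tokens = [[1.0, 2.0], [3.0, 4.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.0, 0.0, 1.0]

pooled = mean_pool(tokens)             # [2.0, 3.0]
logits = linear(pooled, weight, bias)  # [2.0, 3.0, 6.0]
print(pooled, logits)
```

During linear probing, only the weight and bias of this last layer are trained; the embeddings that feed into the pooling step come from the frozen encoder.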
Now, we will also need to modify our load_checkpoint method to ignore the classifier weights as they are not in the checkpoint we will load, so we set the 'strict' flag equal to false:
# -- loading target_encoder
if target_encoder is not None:
    print(list(checkpoint.keys()))
    pretrained_dict = checkpoint['target_encoder']
    msg = target_encoder.load_state_dict(pretrained_dict, strict=False)  ####### added for classifier ########
    logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}')

Setting the strict flag to False tells PyTorch that it is okay for the model to contain layers (here, the classifier) that have no matching weights in the checkpoint. Now we are ready to freeze the backbone while leaving the linear layer unfrozen, so that training only adjusts the weights of the output linear layer:

for name, param in target_encoder.named_parameters():
    print(f"Parameter name: {name}, shape: {param.shape}")
    if 'classifier' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
Now we are ready to train our model! Here's the main training loop:
for epoch in range(0, num_epochs):
    logger.info('Epoch %d' % (epoch + 1))

    # -- update distributed-data-loader epoch
    train_sampler.set_epoch(epoch)

    loss_meter = AverageMeter()
    time_meter = AverageMeter()

    target_encoder.train()
    optimizer.zero_grad()

    for itr, (udata, masks_enc, masks_pred) in enumerate(train_loader):
        imgs = udata[0].to(device, non_blocking=True)
        labels = udata[1].to(device, non_blocking=True)

        def train_step():
            _new_lr = scheduler.step()
            _new_wd = wd_scheduler.step()

            outputs = target_encoder(imgs, return_avg_embed=True)
            loss = F.cross_entropy(outputs, labels)
            loss = AllReduce.apply(loss)

            # Step 1. Backward & step
            if use_bfloat16:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Accumulate gradients
            if (itr + 1) % accum_steps == 0:
                if use_bfloat16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

            return (float(loss), _new_lr, _new_wd)

        (loss, _new_lr, _new_wd), etime = gpu_timer(train_step)
        loss_meter.update(loss, imgs.size(0))
        time_meter.update(etime, imgs.size(0))

        if (itr + 1) % log_freq == 0:
            logger.info(f'Epoch [{epoch + 1}/{num_epochs}], Step [{itr + 1}/{ipe}], '
                        f'Loss: {loss_meter.avg:.4f}, Time: {time_meter.avg:.2f}s')

    # Validation step
    val_loss, val_accuracy = validate(val_loader, target_encoder)
    logger.info(f'Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

    # Log to wandb
    run.log({
        'Epoch': epoch + 1,
        'Train Loss': loss_meter.avg,
        'Validation Loss': val_loss,
        'Validation Accuracy': val_accuracy,
    })

    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        pth = save_checkpoint(epoch)
        best_model_path = pth

run.log_model(path=best_model_path, name=f"supervised_best_model_fraction_{data_fraction}", aliases=["best"])
run.finish()
We are able to reuse our previous dataloader, however since we are simply doing supervised training, we don't use the masks for the predictor or encoder model, and we only use the 'target_encoder' model (frozen) along with an added linear layer (trainable).
This training loop simply uses cross-entropy loss, and the model is tasked with predicting the correct label. We also use W&B logging to log the training and validation losses and accuracy. Something important to note for linear probing with IJEPA is that the learning rate parameter is extremely important. I had the most success with a learning rate of 0.05 for Cifar-10.
It's important to keep in mind that the following results were obtained by training only the linear layer on top of our base vit_huge model. As models become increasingly large, an important property is how well a model generalizes from large amounts of data and how easily it can be adapted without unfreezing the weights of the entire model. This is an area where IJEPA seems to excel.
After just a few epochs, we were able to achieve over 78% accuracy on Cifar-10, which is not bad considering we are only training a linear layer on top of the base model.
Run: jepa_continual_ft_fraction_1.0
Conclusion
Machine learning has made significant strides, especially with supervised learning, but it's not without its challenges, such as the need for large amounts of labeled data. IJEPA presents an exciting new approach to overcome these issues.
IJEPA stands out because it doesn't rely on the traditional hand-crafted data augmentations. Instead, it learns by predicting the representations of masked image regions in latent space, which makes it robust and capable of generalizing well even with limited labeled data. This method allows IJEPA to perform well across a wide range of tasks, from image classification to object counting and depth prediction, without the biases that can come from data augmentations.
In conclusion, IJEPA offers a transformative approach to self-supervised learning. Its ability to handle diverse tasks with greater efficiency and less reliance on labeled data marks a significant step forward in the field. As machine learning continues to evolve, innovations like IJEPA will be crucial in pushing the boundaries of what AI can achieve.