
CapPa: Training vision models as captioners

Open-source reproduction of "Image Captioners are Scalable Vision Learners Too"
Created on June 25|Last edited on July 8


Introduction

Vision models are commonly used in diverse tasks: classification, segmentation, captioning, image generation/editing, and now even in LLMs for vision capabilities. They're typically pre-trained either as classifiers (using ImageNet or similar) or with contrastive techniques based on CLIP.
In "Image Captioners Are Scalable Vision Learners Too", the authors show that we can train competitive vision models as image "captioners" on noisy datasets (the same ones used for CLIP). It's an interesting alternative: captioning requires a detailed understanding of an image, so CapPa could lead to more powerful vision models.
This project offers an open-source reproduction (full code in JAX/Flax including training, use of an open-source dataset, model weights) of the paper while exploring a few additional techniques.

Objective

The model is trained using:
  • Captioning: We feed images as inputs and try to predict the caption
  • Masked prediction: We mask part of the caption and predict all the masked tokens simultaneously
Source: Image Captioners Are Scalable Vision Learners Too, Figure 1

The Cap variant uses only the captioning objective while the CapPa variant also adds the masked prediction.
In our experiment we follow the CapPa variant and use the following combination of objectives, as recommended in the paper:
  • Captioning: 25% of the time
  • Masking: we drop the entire caption 75% of the time
Cap models already perform well but, amazingly enough, dropping the entire caption 75% of the time (as in CapPa) leads to even better results.
Intuitively this is interesting: we can imagine how seeing the start of the caption has more impact on predicting the next word than the image itself does, whereas our objective here is to train a strong vision model. In fact, the decoder (text model with cross-attention) is typically discarded and we only retain the vision model for downstream tasks.
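To make the objective mix concrete, here is a minimal sketch of the per-example choice in JAX (the mask token id, ratio constant, and function interface are illustrative assumptions, not the exact training code):

import jax
import jax.numpy as jnp

MASK_TOKEN_ID = 3          # hypothetical id of the mask token
DROP_CAPTION_RATIO = 0.75  # parallel (masked) prediction 75% of the time


def build_decoder_inputs(rng, caption_ids):
    """Return (decoder_inputs, use_causal_mask) for a batch of captions.

    With probability 0.75 the whole caption is replaced by mask tokens and
    all tokens are predicted in parallel (no causal mask); otherwise the
    caption is kept and predicted autoregressively (captioning objective).
    The targets are always the original caption tokens.
    """
    batch_size = caption_ids.shape[0]
    drop = jax.random.bernoulli(rng, DROP_CAPTION_RATIO, (batch_size, 1))
    masked = jnp.full_like(caption_ids, MASK_TOKEN_ID)
    decoder_inputs = jnp.where(drop, masked, caption_ids)
    use_causal_mask = ~drop[:, 0]  # causal attention only when the caption is kept
    return decoder_inputs, use_causal_mask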

Dataset

  • We use DataComp 1B and remove exact duplicates of image + caption pairs (a dedup sketch follows this list).
  • One image can have multiple captions and a caption can be used for different images.
  • The validation set is made of images that don’t have any other exact duplicate (regardless of the caption).
  • During training, we don’t apply any augmentation on images.
  • Images are resized to 256 x 256 (TPUs like multiples of 128 and it limits storage/egress costs)
  • About 45% of images have an edge smaller than 256 pixels but we still keep them as a form of data augmentation.
  • We drop the entire caption 75% of the time (with masking objective instead of captioning).
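As a rough illustration, exact duplicates can be dropped by hashing the raw image bytes together with the caption (a minimal sketch; the field names are assumptions, not the actual preprocessing pipeline):

import hashlib


def dedup_exact(samples):
    """Yield samples, skipping exact (image, caption) duplicates.

    `samples` is assumed to be an iterable of dicts with hypothetical
    'image_bytes' and 'caption' fields.
    """
    seen = set()
    for sample in samples:
        key = hashlib.sha256(
            sample["image_bytes"] + sample["caption"].encode("utf-8")
        ).hexdigest()
        if key not in seen:
            seen.add(key)
            yield sample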

Model architecture

The architecture is mostly inspired by "An Evaluation of Transformer Variants".
Number of parameters:
  • Vision: 328M
  • Text: 348M (includes 67M for embeddings)
Vision model:
  • Patches of 16x16 on images of 256x256 pixels, i.e. 256 patches per image
  • Learnt positions
  • MLP uses GeGLU activations (linear gate + GELU)
  • No bias
  • RMSNorm in Normformer positions (start + end of attention blocks & start + mid of feed-forward blocks); see the sketch after this list
  • 24 layers, hidden dimension of 1024, mlp dimension of 3072
  • Use of vision registers (see "Vision Transformers Need Registers")
    • we use 8 registers as it trains faster on TPUs than with fewer
    • we don't discard any of the registers since we use cross-attention to the text (vs the CLS token in the paper); in downstream tasks we would use a MAP head keeping all tokens + registers
Text model:
  • Decoder (causal when captioning) with cross-attention on images
  • Same as vision except 12 layers (half)
  • Max length of 64 tokens
  • Bias only on final layer predicting vocab
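To make the block design concrete, here is a minimal Flax sketch of the feed-forward block with GeGLU and RMSNorm in the Normformer positions (an illustration of the description above, not the exact model code from the repo):

import flax.linen as nn


class GeGLUFeedForward(nn.Module):
    """Feed-forward block: RMSNorm at the start and middle (Normformer
    positions), GeGLU activation (linear gate * GELU), no biases, and a
    residual connection. Dimensions match the vision tower above."""

    hidden_dim: int = 1024
    mlp_dim: int = 3072

    @nn.compact
    def __call__(self, x):
        h = nn.RMSNorm()(x)                               # norm at block start
        gate = nn.Dense(self.mlp_dim, use_bias=False)(h)  # linear gate
        up = nn.Dense(self.mlp_dim, use_bias=False)(h)
        h = gate * nn.gelu(up)                            # GeGLU
        h = nn.RMSNorm()(h)                               # norm at block middle
        h = nn.Dense(self.hidden_dim, use_bias=False)(h)
        return x + h                                      # residual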

Training

Training Configuration

  • Hardware: TPU v5e-256
  • Framework: JAX/Flax
  • Optimizer: distributed shampoo with RMSProp normalized grafting, based on "Evaluation of Distributed Shampoo"
  • No weight decay, no dropout
  • Weights in float32, computation in bfloat16 except for attention logits, normalization layers, and loss
  • Sharding: FSDP, with model and data dimensions sharded across all devices (see the sketch after this list)
  • Batch size of 8,192 - Note that 16k also fits well with this sharding strategy
  • Training speed: 0.45 seconds / batch, 1.6B samples per day
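A rough sketch of this kind of 2D mesh setup in JAX (the mesh shape and partition specs below are illustrative assumptions, not the exact training configuration):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2D mesh with a "data" and a "model" axis; the exact shape
# depends on the TPU slice (e.g. 32 x 8 on a v5e-256).
devices = np.array(jax.devices()).reshape(-1, 8)
mesh = Mesh(devices, axis_names=("data", "model"))

# Activations: the batch dimension is sharded over both axes, i.e. all devices.
batch_sharding = NamedSharding(mesh, P(("data", "model")))

# Parameters (FSDP-style): one dimension of each weight is sharded over "model".
param_sharding = NamedSharding(mesh, P(None, "model"))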

Training metrics (updated live)



Notes:
  • The validation loss uses only the captioning objective, which explains why it is much lower than the training loss.
  • To derive top-1/top-5 scores on ImageNet:
    • We use the lowest softmax cross-entropy of each possible class against each image, as we cannot compute text embeddings separately like with CLIP models (a sketch of this scoring follows these notes)
    • Since this is much slower to compute, we only use 5% of the data
    • It is expected to be lower than fine-tuning a CLIP model (SigLiT style) since the model is not trained explicitly for this task, and the current method may be affected by the different token lengths of each class name
  • You can explore additional metrics (gradients, parameter norms…) in the main workspace or in individual runs:
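Below is a rough sketch of that scoring scheme (the model.apply interface shown is an assumption for illustration; the actual evaluation code lives in the Clip-JAX repo):

import jax
import jax.numpy as jnp
import optax


def class_scores(params, pixel_values, class_token_ids, class_token_mask):
    """Per-class captioning loss for a single image.

    class_token_ids:  (num_classes, seq_len) tokenized class names
    class_token_mask: (num_classes, seq_len) 1 for real tokens, 0 for padding
    """

    def loss_for_class(tokens, mask):
        # Assumed interface: pixel_values + shifted input_ids -> next-token logits
        logits = model.apply(
            {"params": params},
            pixel_values=pixel_values[None],  # add batch dimension
            input_ids=tokens[None, :-1],      # teacher forcing on the class name
        )
        ce = optax.softmax_cross_entropy_with_integer_labels(logits[0], tokens[1:])
        return (ce * mask[1:]).sum() / mask[1:].sum()  # ignore padding tokens

    return jax.vmap(loss_for_class)(class_token_ids, class_token_mask)


# Top-1 prediction = class with the lowest captioning loss:
# predicted_class = jnp.argmin(class_scores(params, image, class_ids, class_mask))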

Sample predictions (updated live)





Training Log

Results

Visual inspection

One of the big advantages of CapPa models vs CLIP is that they are easier to interpret. Performing inference on images lets us efficiently probe and visualize the model's knowledge:
  • the model is good at OCR
  • it has strong cultural knowledge
  • some details may be hallucinated
  • there is likely leakage from near duplicates in the training data (see the last row of the table below)
Here are a few interesting samples extracted from the "Sample Predictions" section:
| Image | Ground Truth | Prediction |
| --- | --- | --- |
| Sample Image 1 | kristen coates art and home holiday 01.png | christmas coates holiday shop |
| Sample Image 2 | my sheltie loves agility design, sable tshirts | sheltie agility t shirt |
| Sample Image 3 | the poetry of living off the grid | the road washes out in spring |
| Sample Image 4 | articles of confederation are ratified | the constitution |
| Sample Image 5 | how to draw spider man full body | how to draw a man side view |
| Sample Image 6 | hotel arenal bromelias, view from the street | hostal la merced |
| Sample Image 7 | the prince fielder era came to an end for the tigers over the offseason | prince fielder |
| Sample Image 8 | south shore fynn 2 drawer nightstand gray oak, 0 | south shore fynn collection nightstand, gray oak |



CapPa to SigLiT

CapPa shows strong zero-shot accuracy on ImageNet (63% top-1 and 85% top-5, see the "Training" section) despite not being trained with a classification or contrastive objective.
When freezing the vision tower and training only a small text tower, as in SigLiT from "Sigmoid Loss for Language Image Pre-Training", we reach 74% top-1 and 93% top-5. Training is fast and efficient: we trained for 6B samples but were already starting to overfit, so fewer samples would probably have been enough.
In fact, the purpose of CapPa is to provide a strong vision tower for downstream applications, so we can adjust the architecture and quickly train a model for contrastive learning (as here with SigLiT), classification, object detection, OCR, segmentation, text-to-image encoders, and even VLMs.
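For reference, the sigmoid contrastive objective used at this stage looks roughly as follows (a minimal sketch assuming L2-normalized embeddings; the actual SigLiT training also handles the learnable temperature/bias initialization and sharding):

import jax
import jax.numpy as jnp


def siglip_loss(image_emb, text_emb, t, b):
    """Sigmoid contrastive loss: matching image-text pairs (the diagonal)
    get label +1, every other pair gets label -1. t (temperature) and b
    (bias) are learnable scalars."""
    logits = image_emb @ text_emb.T * t + b        # (n, n) pairwise logits
    labels = 2.0 * jnp.eye(logits.shape[0]) - 1.0  # +1 on diagonal, -1 elsewhere
    return -jnp.sum(jax.nn.log_sigmoid(labels * logits)) / logits.shape[0]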





Composition understanding - SugarCrepe

The model beats every open-source model on all categories but one on SugarCrepe.
This benchmark probes models for fine-grained compositional understanding by replacing, swapping, or adding objects/attributes/relations in a plausible manner, such that a simple text-only model cannot identify the most likely caption based on occurrence in the dataset.
| Source | Model | Data Size | Model Size (M) | Replace: Object | Replace: Attribute | Replace: Relation | Swap: Object | Swap: Attribute | Add: Object | Add: Attribute |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Human | | | | 100 | 99 | 97 | 99 | 100 | 99 | 99 |
| Text-only model | Vera | | | 49.39 | 49.62 | 49.36 | 49.19 | 49.40 | 49.42 | 49.57 |
| Text-only model | Grammar | | | 50.00 | 50.00 | 50.00 | 50.00 | 50.00 | 50.00 | 50.00 |
| OpenAI | RN50 | 400M | 102 | 91.77 | 80.58 | 69.99 | 61.79 | 68.47 | 74.54 | 69.65 |
| OpenAI | RN101 | 400M | 120 | 92.49 | 83.88 | 67.07 | 56.50 | 65.92 | 75.46 | 70.09 |
| OpenAI | ViT-B-32 | 400M | 151 | 90.92 | 80.08 | 69.20 | 61.38 | 63.96 | 77.21 | 68.79 |
| OpenAI | RN50x4 | 400M | 178 | 92.68 | 82.99 | 67.57 | 65.04 | 63.36 | 79.34 | 70.09 |
| OpenAI | RN50x16 | 400M | 291 | 93.46 | 82.11 | 69.20 | 63.01 | 65.77 | 80.70 | 75.87 |
| OpenAI | ViT-L-14 | 400M | 428 | 94.07 | 79.19 | 65.15 | 60.16 | 62.31 | 78.32 | 71.53 |
| OpenAI | RN50x64 | 400M | 623 | 94.49 | 83.50 | 70.63 | 61.79 | 66.67 | 83.27 | 73.99 |
| LAION | roberta-ViT-B-32 | 2B | 212 | 92.86 | 84.90 | 72.40 | 63.01 | 71.02 | 87.34 | 79.91 |
| LAION | ViT-H-14 | 2B | 986 | 96.49 | 84.77 | 71.76 | 67.48 | 73.12 | 92.05 | 85.84 |
| LAION | ViT-g-14 | 2B | 1367 | 95.76 | 85.03 | 72.40 | 63.01 | 71.17 | 91.51 | 82.08 |
| LAION | ViT-bigG-14 | 2B | 2540 | 96.67 | 88.07 | 74.75 | 62.20 | 74.92 | 92.19 | 84.54 |
| LAION | xlm-roberta-base-ViT-B-32 | 5B | 366 | 93.16 | 84.01 | 69.20 | 63.41 | 67.57 | 87.78 | 81.07 |
| LAION | xlm-roberta-large-ViT-H-14 | 5B | 1193 | 96.85 | 86.04 | 72.05 | 63.82 | 72.07 | 93.11 | 86.13 |
| DataComp | small:ViT-B-32 | 13M | 151 | 56.90 | 56.85 | 51.99 | 50.81 | 50.00 | 53.93 | 60.55 |
| DataComp | medium:ViT-B-32 | 128M | 151 | 77.00 | 69.54 | 57.68 | 57.72 | 57.06 | 66.73 | 64.88 |
| DataComp | large:ViT-B-16 | 1B | 150 | 92.68 | 79.82 | 63.94 | 56.10 | 57.66 | 84.34 | 78.61 |
| DataComp | xlarge:ViT-L-14 | 13B | 428 | 95.52 | 84.52 | 69.99 | 65.04 | 66.82 | 91.03 | 84.97 |
| Open-source | CapPa ViT-L-16 with registers | 1B | 676 | 89.65 | 86.80 | 82.50 | 77.14 | 85.14 | 98.25 | 99.42 |



Register attention maps

The main observation of "Vision Transformers Need Registers" is that ViT models often leverage patches with low information (such as uniform background) to encode global information. Adding registers gives the model a better place to encode the global information.
In their implementation, the CLS token is used for final pooling. Since we have a captioner, the visual tokens are pooled through cross-attention. We decide to keep both the visual patches and the registers in the cross-attention.
In our CapPa model, we can visualize the attention from the image to any register:

In this example the rabbits have a stronger attention to "register 2" than the rest of the image.


We can also see the attention from the register to the image, which is probably more interesting.

Here we can see that "register 2" pays more attention to the rabbit ears.
The attention weights all stay within a relatively bounded range, without high noise on random patches of the image, so we can expect the registers to be used efficiently to encode global information.
As a side note, we can automatically retrieve intermediate outputs in JAX/Flax using capture_intermediates.
output, mod_vars = model.apply(
    {"params": params},
    **batch,
    # NOTE: added to capture intermediate values
    capture_intermediates=True,
    mutable=["intermediates"],
)

attention_weights = mod_vars["intermediates"]["vision"]["encoder"]["layers"]["attention"]["attention_weights"]
In our examples we displayed the maps from the first attention head. The norm and mean over attention heads are less interesting as each head attends to different parts of the image.

How to use the model

Refer to the CapPa inference notebook from the Clip-JAX repo.
You can jit the function calls for faster inference as in the demo notebook. The model generation interface supports greedy search, temperature sampling with top_p/top_k, as well as beam search.
from functools import partial

import jax


@partial(
    jax.jit,
    static_argnames=("num_beams", "do_sample", "temperature", "top_p", "top_k", "max_length", "num_return_sequences"),
)
def generate_caption(pixel_values, *args, **kwargs):
    return model.generate(pixel_values, *args, **kwargs)
Please keep in mind that the main purpose of the CapPa model is its vision tower: for downstream applications, the text tower is typically discarded.
The model is also being ported to 🤗 Transformers.

Resources

Acknowledgements
