CapPa: Training vision models as captioners
Open-source reproduction of "Image Captioners are Scalable Vision Learners Too"
Introduction
Vision models are commonly used for diverse tasks: classification, segmentation, captioning, image generation/editing and, more recently, as the vision component of LLMs. They're typically pre-trained either as classifiers (using ImageNet or similar) or with contrastive techniques based on CLIP.
In "Image Captioners Are Scalable Vision Learners Too", the authors show that we can train competitive vision models as image "captioners" on noisy datasets (the same ones used for CLIP). It's an interesting alternative since captioning requires a detailed understanding of an image, so CapPa could lead to more powerful vision models.
This project offers an open-source reproduction (full code in JAX/Flax including training, use of an open-source dataset, model weights) of the paper while exploring a few additional techniques.
Objective
The model is trained using:
- Captioning: We feed images as inputs and try to predict the caption
- Masked prediction: We mask part of the caption and predict all the masked tokens simultaneously

Source: Image Captioners Are Scalable Vision Learners Too, Figure 1
The Cap variant uses only the captioning objective while the CapPa variant also adds the masked prediction.
In our experiment we follow the CapPa variant and use the following combination of objectives, as recommended in the paper:
- Captioning: 25% of the time
- Masked prediction: we drop the entire caption and predict all tokens in parallel 75% of the time
Cap models already perform well but, surprisingly, dropping the entire caption 75% of the time (as in CapPa) leads to even better results.
Intuitively this makes sense: when the start of the caption is available, the next word can often be predicted from the preceding text rather than from the image itself, while our goal here is to train a strong vision model. In fact the decoder (a text model with cross-attention) is typically discarded and we only keep the vision model for downstream tasks. A minimal sketch of the objective mix is shown below.
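To make the mix concrete, here is a minimal sketch (not the exact training code) of how the two objectives can be combined when building decoder inputs. `MASK_ID` and the helper name are illustrative, and the sketch glosses over the one-token shift used for autoregressive captioning and the switch between causal and bidirectional attention masks.
```python
import jax
import jax.numpy as jnp

MASK_ID = 3           # illustrative mask token id
PARALLEL_PROB = 0.75  # fraction of examples trained with parallel (masked) prediction

def make_decoder_inputs(rng, caption_ids):
    """Per example, either keep the caption (autoregressive captioning, 25%)
    or replace every token with [MASK] (parallel prediction, 75%).
    The targets are always the original caption tokens."""
    use_parallel = jax.random.bernoulli(rng, PARALLEL_PROB, (caption_ids.shape[0], 1))
    fully_masked = jnp.full_like(caption_ids, MASK_ID)
    decoder_inputs = jnp.where(use_parallel, fully_masked, caption_ids)
    return decoder_inputs, caption_ids  # (inputs, targets)
```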
Dataset
- We use DataComp 1B and remove exact duplicates of image + caption.
- One image can have multiple captions and a caption can be used for different images.
- The validation set is made of images that don’t have any other exact duplicate (regardless of the caption).
- During training, we don’t apply any augmentation on images.
- Images are resized to 256 x 256 (TPUs like multiples of 128 and it limits storage/egress costs)
- About 45% of images have an edge smaller than 256 pixels but we still keep them as a form of data augmentation.
- We drop the entire caption 75% of the time (using the masked-prediction objective instead of captioning).
Model architecture
Number of parameters:
- Vision: 328M
- Text: 348M (includes 67M for embeddings)
Vision model:
- Patches of 16x16 on images of 256x256 pixels, ie 256 patches per image
- Learnt position embeddings
- MLP uses GeGLU activations (linear + GELU)
- No bias
- RMSNorm with Normformer positions (start + end of attention blocks & start + mid of feed-forward blocks); see the sketch after this section
- 24 layers, hidden dimension of 1024, mlp dimension of 3072
- We add 8 registers: we use 8 because it trains faster on TPUs than a smaller number
- We don’t discard any of the registers since we pool visual tokens through cross-attention to the text decoder rather than a CLS token as in the paper; in downstream tasks we would use a MAP head over all tokens + registers
Text model:
- Decoder (causal when captioning) with cross-attention on images
- Same as vision except 12 layers (half)
- Max length of 64 tokens
- Bias only on the final layer predicting the vocabulary
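For illustration, here is a minimal Flax sketch of the feed-forward block described above (GeGLU activation, no biases, RMSNorm in Normformer positions). The module and its default sizes are our own naming for this report, not the exact training code.
```python
import flax.linen as nn

class GeGLUBlock(nn.Module):
    """Feed-forward block: RMSNorm (start), GeGLU, RMSNorm (mid), projection, residual."""
    hidden_dim: int = 1024
    mlp_dim: int = 3072

    @nn.compact
    def __call__(self, x):
        h = nn.RMSNorm()(x)                                       # Normformer: norm at block start
        gate = nn.Dense(self.mlp_dim, use_bias=False)(h)          # linear branch
        up = nn.gelu(nn.Dense(self.mlp_dim, use_bias=False)(h))   # GELU branch
        h = gate * up                                              # GeGLU: element-wise product
        h = nn.RMSNorm()(h)                                        # Normformer: norm mid-block
        h = nn.Dense(self.hidden_dim, use_bias=False)(h)
        return x + h                                               # residual connection
```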
Training
Training Configuration
- Hardware: TPU v5e-256
- Framework: JAX/Flax
- Optimizer: distributed shampoo with RMSProp normalized grafting, based on "Evaluation of Distributed Shampoo"
- No weight decay, no dropout
- Weights in float32, computation in bfloat16 except for attention logits, normalization layers, and loss
- Sharding: FSDP, model and data dimension sharded across all devices
- Batch size of 8,192 - Note that 16k also fits well with this sharding strategy
- Training speed: 0.45 seconds / batch, 1.6B samples per day
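As a rough sketch of what the sharding setup can look like with jax.sharding (the axis sizes and names here are illustrative, not necessarily the exact training configuration):
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 2D mesh over all devices, e.g. 32 x 8 on a v5e-256
devices = np.array(jax.devices()).reshape(-1, 8)
mesh = Mesh(devices, axis_names=("data", "model"))

# The batch dimension is split along "data"; parameters are sharded across both
# axes (FSDP-style) so each device only holds a slice of the weights.
x = jnp.zeros((8192, 256, 1024))  # activations: (batch, tokens, hidden)
x = jax.device_put(x, NamedSharding(mesh, P("data", None, None)))
w = jnp.zeros((1024, 3072))       # an MLP weight matrix
w = jax.device_put(w, NamedSharding(mesh, P("data", "model")))
```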
Training metrics (updated live)
Notes:
- The validation loss uses only the captioning objective, which explains why it is much lower than the training loss.
- To derive top 1/5 scores on ImageNet:
- We use the lowest softmax cross-entropy of each possible class caption against each image, since we cannot pre-compute text embeddings separately like with CLIP models (see the sketch after these notes)
- Since this is much slower to compute, we only use 5% of the data
- It is expected to be lower than if we fine-tuned a CLIP-style model (SigLiT style) as the model is not trained explicitly for that task, and the current method may be affected by the different token lengths of each class name
- You can explore additional metrics (gradients, parameter norms…) in the main workspace or in individual runs:
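Here is a simplified sketch of that scoring scheme. The model/tokenizer interfaces are hypothetical, and it glosses over the shift between decoder inputs and targets.
```python
import jax.numpy as jnp
import optax

def class_loss(params, pixel_values, class_tokens, pad_id=0):
    """Captioning loss of one tokenized class name for a batch of images."""
    batch = pixel_values.shape[0]
    tokens = jnp.tile(class_tokens[None], (batch, 1))            # (batch, seq_len)
    # Hypothetical interface: returns logits of shape (batch, seq_len, vocab)
    logits = model.apply({"params": params}, pixel_values=pixel_values,
                         decoder_input_ids=tokens)
    mask = (tokens != pad_id).astype(jnp.float32)
    ce = optax.softmax_cross_entropy_with_integer_labels(logits, tokens)
    return (ce * mask).sum(-1)                                   # (batch,)

# Score every class caption and predict the one with the lowest loss per image
losses = jnp.stack([class_loss(params, images, t) for t in class_token_ids])  # (n_classes, batch)
top1_prediction = jnp.argmin(losses, axis=0)
```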
Sample predictions (updated live)
Training Log
Results
Visual inspection
One of the big advantages of CapPa models vs CLIP is that they offer a clearer interpretation. Performing inference on images lets us efficiently probe and visualize the model's knowledge:
- the model is good at OCR
- it has strong cultural knowledge
- some details may be hallucinated
- there is likely leakage from near duplicates in the training data (see the last row of the table below)
Here are a few interesting samples extracted from "Sample Predictions" section:
| Image | Ground Truth | Prediction |
|---|---|---|
| ![]() | kristen coates art and home holiday 01.png | christmas coates holiday shop |
| ![]() | my sheltie loves agility design, sable tshirts | sheltie agility t shirt |
| ![]() | the poetry of living off the grid | the road washes out in spring |
| ![]() | articles of confederation are ratified | the constitution |
| ![]() | how to draw spider man full body | how to draw a man side view |
| ![]() | hotel arenal bromelias, view from the street | hostal la merced |
| ![]() | the prince fielder era came to an end for the tigers over the offseason | prince fielder |
| ![]() | south shore fynn 2 drawer nightstand gray oak, 0 | south shore fynn collection nightstand, gray oak |
CapPa to SigLiT
CapPa shows strong zero-shot accuracy on ImageNet (63% top 1 and 85% top 5, see the "Training" section) despite not being trained with a classification or contrastive objective.
When freezing the vision tower and training only a small text tower, as in SigLiT from "Sigmoid Loss for Language Image Pre-Training", we reach 74% top 1 and 93% top 5. Training is fast and efficient: we trained on 6B samples, but the model was already starting to overfit and could probably have been trained on fewer samples.
In fact, the purpose of CapPa is to provide a strong vision tower for downstream applications, so we can adjust the architecture and quickly train a model for contrastive learning (as here with SigLiT), classification, object detection, OCR, segmentation, text-to-image encoders and even VLMs.
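For reference, the small text tower is trained against the frozen CapPa vision tower with the pairwise sigmoid loss from "Sigmoid Loss for Language Image Pre-Training". A minimal sketch (the embedding and temperature handling is illustrative, not our exact training code):
```python
import jax
import jax.numpy as jnp

def sigmoid_contrastive_loss(image_emb, text_emb, temperature, bias):
    """Pairwise sigmoid loss: matching (image, text) pairs are positives (+1),
    every other pair in the batch is a negative (-1). Embeddings are L2-normalized."""
    logits = image_emb @ text_emb.T * temperature + bias   # (n, n)
    labels = 2.0 * jnp.eye(logits.shape[0]) - 1.0          # +1 on the diagonal, -1 elsewhere
    # Sum over pairs, average over the batch (as in the paper)
    return -jnp.mean(jnp.sum(jax.nn.log_sigmoid(labels * logits), axis=-1))
```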
Composition understanding - SugarCrepe
The model beats every open-source model on all but one category of SugarCrepe.
This benchmark probes models for fine-grained compositional understanding by replacing, swapping or adding objects/attributes/relations in a plausible manner, so that a text-only model cannot identify the correct caption simply from how likely it is as a sentence.
See our results below added to Table 6 of "SugarCrepe: Fixing Hackable Benchmarks for Vision-Language Compositionality".
| Source | Model | Data Size | Model Size (M) | Replace: Object | Replace: Attribute | Replace: Relation | Swap: Object | Swap: Attribute | Add: Object | Add: Attribute |
|---|---|---|---|---|---|---|---|---|---|---|
| Human | | | | 100 | 99 | 97 | 99 | 100 | 99 | 99 |
| Text-only model | Vera | | | 49.39 | 49.62 | 49.36 | 49.19 | 49.40 | 49.42 | 49.57 |
| | Grammar | | | 50.00 | 50.00 | 50.00 | 50.00 | 50.00 | 50.00 | 50.00 |
| OpenAI | RN50 | | 102 | 91.77 | 80.58 | 69.99 | 61.79 | 68.47 | 74.54 | 69.65 |
| | RN101 | | 120 | 92.49 | 83.88 | 67.07 | 56.50 | 65.92 | 75.46 | 70.09 |
| | ViT-B-32 | | 151 | 90.92 | 80.08 | 69.20 | 61.38 | 63.96 | 77.21 | 68.79 |
| | RN50x4 | 400M | 178 | 92.68 | 82.99 | 67.57 | 65.04 | 63.36 | 79.34 | 70.09 |
| | RN50x16 | | 291 | 93.46 | 82.11 | 69.20 | 63.01 | 65.77 | 80.70 | 75.87 |
| | ViT-L-14 | | 428 | 94.07 | 79.19 | 65.15 | 60.16 | 62.31 | 78.32 | 71.53 |
| | RN50x64 | | 623 | 94.49 | 83.50 | 70.63 | 61.79 | 66.67 | 83.27 | 73.99 |
| LAION | roberta-ViT-B-32 | | 212 | 92.86 | 84.90 | 72.40 | 63.01 | 71.02 | 87.34 | 79.91 |
| | ViT-H-14 | 2B | 986 | 96.49 | 84.77 | 71.76 | 67.48 | 73.12 | 92.05 | 85.84 |
| | ViT-g-14 | | 1367 | 95.76 | 85.03 | 72.40 | 63.01 | 71.17 | 91.51 | 82.08 |
| | ViT-bigG-14 | | 2540 | 96.67 | 88.07 | 74.75 | 62.20 | 74.92 | 92.19 | 84.54 |
| | xlm-roberta-base-ViT-B-32 | | 366 | 93.16 | 84.01 | 69.20 | 63.41 | 67.57 | 87.78 | 81.07 |
| | xlm-roberta-large-ViT-H-14 | 5B | 1193 | 96.85 | 86.04 | 72.05 | 63.82 | 72.07 | 93.11 | 86.13 |
| DataComp | small:ViT-B-32 | 13M | 151 | 56.90 | 56.85 | 51.99 | 50.81 | 50.00 | 53.93 | 60.55 |
| | medium:ViT-B-32 | 128M | 151 | 77.00 | 69.54 | 57.68 | 57.72 | 57.06 | 66.73 | 64.88 |
| | large:ViT-B-16 | 1B | 150 | 92.68 | 79.82 | 63.94 | 56.10 | 57.66 | 84.34 | 78.61 |
| | xlarge:ViT-L-14 | 13B | 428 | 95.52 | 84.52 | 69.99 | 65.04 | 66.82 | 91.03 | 84.97 |
| Open-source CapPa | ViT-L-16 with registers | 1B | 676 | 89.65 | 86.80 | 82.50 | 77.14 | 85.14 | 98.25 | 99.42 |
Register attention maps
The main observation of "Vision Transformers Need Registers" is that ViT models often leverage patches with low information (such as uniform background) to encode global information. Adding registers gives the model a better place to encode the global information.
In their implementation, a BOS token is used for final pooling. Since we have a captioner, the visual tokens are pooled through cross-attention. We decide to keep both the visual patches and the registers in the cross-attention.
In our CapPa model, we can visualize the attention from the image to any register:

In this example, the patches covering the rabbits attend more strongly to "register 2" than the rest of the image does.
We can also see the attention from the register to the image, which is probably more interesting.

Here we can see that "register 2" pays more attention to the rabbit ears.
The attention weights all stay within a relatively bounded range, without strong spikes towards random patches of the image, so we can expect the registers to be used efficiently to encode global information.
As a side note, we can automatically retrieve intermediate outputs in JAX/Flax using capture_intermediates.
```python
output, mod_vars = model.apply(
    {"params": params},
    **batch,
    # NOTE: added to capture intermediate values
    capture_intermediates=True,
    mutable=["intermediates"],
)
attention_weights = mod_vars["intermediates"]["vision"]["encoder"]["layers"]["attention"]["attention_weights"]
```
In our examples we displayed the maps from the first attention head. The norm and mean over attention heads are less interesting as each head attends to different parts of the image. A rough plotting sketch is shown below.
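As an illustration, a register-to-image map like the one above can be rendered from the captured weights. The token ordering (256 patch tokens followed by the 8 registers) and the exact tensor layout are assumptions here and may need adjusting to the captured structure.
```python
import matplotlib.pyplot as plt

# Assumed layout: (batch, num_heads, query_tokens, key_tokens), with 256 patches then 8 registers
head, register = 0, 2
attn = attention_weights[0, head]                # first image, first head
register_to_image = attn[256 + register, :256]   # attention from the register to each patch
plt.imshow(register_to_image.reshape(16, 16))    # back to the 16x16 patch grid
plt.title(f"Attention of register {register} over image patches")
plt.colorbar()
plt.show()
```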
How to use the model
You can jit the function calls for faster inference as in the demo notebook. The model generation interface supports greedy search, temperature sampling with top_p/top_k, as well as beam search.
```python
from functools import partial
import jax

@partial(
    jax.jit,
    static_argnames=(
        "num_beams",
        "do_sample",
        "temperature",
        "top_p",
        "top_k",
        "max_length",
        "num_return_sequences",
    ),
)
def generate_caption(pixel_values, *args, **kwargs):
    return model.generate(pixel_values, *args, **kwargs)
```
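For example, a hypothetical call with nucleus sampling (the tokenizer and the exact shape of the returned object are assumptions here, following an HF-style generate API):
```python
output = generate_caption(
    pixel_values,            # preprocessed images, e.g. (batch, 256, 256, 3)
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    max_length=64,
)
captions = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
```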
Please keep in mind that the main purpose of the CapPa model is to only keep the vision tower and discard the text tower for downstream applications.
The model is also being ported to 🤗 Transformers.
Resources
- "Image Captioners Are Scalable Vision Learners Too", Michael Tschannen and Manoj Kumar and Andreas Steiner and Xiaohua Zhai and Neil Houlsby and Lucas Beyer
- "Sigmoid Loss for Language Image Pre-Training", Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer
- "Vision Transformers Need Registers", Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski
- "SugarCrepe: Fixing Hackable Benchmarks for Vision-Language Compositionality", Cheng-Yu Hsieh and Jieyu Zhang and Zixian Ma and Aniruddha Kembhavi and Ranjay Krishna
Acknowledgements
- Weights & Biases for the infrastructure for experiment tracking and model management
- Hugging Face for the support during the project