New Data Selection Method Speeds Up Contrastive Learning by 10x

DeepMind introduces JEST for training models up to 13× faster!
In the world of machine learning, data quality plays a crucial role in determining model performance. However, manually curating large-scale datasets is labor-intensive and doesn't scale well. Enter JEST (Joint Example Selection), a groundbreaking approach that significantly accelerates multimodal learning by optimizing the data selection process.

Traditional Data Selection

Traditional data selection methods focus on individual examples, often neglecting the interplay between data points. This approach can lead to inefficient learning, especially in multimodal settings where the relationship between different types of data (e.g., text and images) is key. JEST addresses this by evaluating and selecting data in batches, considering the dependencies and interactions within these batches.

Background on JEST

JEST employs two models during training: a learner model and a reference model. The learner model is the primary model being trained, starting with little knowledge and improving over time. In contrast, the reference model is a pretrained model that has been trained on a smaller, well-curated dataset. This reference model, having learned from high-quality data, serves as a benchmark to guide the selection of the most useful data for the learner model.

Scoring

The scoring mechanism of JEST is at the heart of its effectiveness. The learner model's loss indicates how hard each example currently is for the model, with a higher loss signifying difficulty. The reference model's loss reflects how well the pretrained reference model handles the same example, with a lower loss indicating higher-quality data. JEST combines these into a learnability score, the learner loss minus the reference loss, which prioritizes examples that are difficult for the learner yet easy for the reference model. This ensures that the selected data is both challenging and genuinely informative for the learner.
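
Here is a minimal sketch of that idea in PyTorch. The `per_example_loss` helper is a hypothetical stand-in for whatever loss your models expose, and treating losses as independent per-example values is a simplification: the paper actually scores joint sub-batches under a multimodal contrastive loss.

```python
import torch

def learnability_scores(batch, learner, reference):
    """Score examples by learnability: hard for the learner
    (high learner loss) but easy for the pretrained reference
    (low reference loss). Higher scores are selected first."""
    with torch.no_grad():
        learner_loss = learner.per_example_loss(batch)      # shape: (batch_size,)
        reference_loss = reference.per_example_loss(batch)  # shape: (batch_size,)
    # Learnability = difficult for the learner, easy for the reference.
    return learner_loss - reference_loss
```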

Batch Selection

The batch selection process in JEST starts with the formation of a large pool of data, known as a super-batch. Initial scoring involves calculating the scores for all examples using both the learner and reference models. The selection then proceeds iteratively. Initially, the most "learnable" examples are picked, particularly those easy for the reference model. Subsequently, more examples are added in chunks, each time selecting those that balance being challenging for the learner and easy for the reference model. This dynamic approach continuously optimizes the batch composition.
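
A simplified sketch of that selection loop is below. The chunk count, the softmax-style sampling, and the `select_batch` helper are illustrative assumptions; the paper's blocked sampling additionally rescores candidates jointly, conditioned on the examples already chosen, which this independent-score version omits.

```python
import numpy as np

def select_batch(scores, batch_size, n_chunks=16, rng=None):
    """Build a training batch from a scored super-batch in chunks,
    sampling without replacement with probability proportional to
    exp(learnability score), so highly learnable examples are
    favored while the batch keeps some diversity."""
    rng = rng or np.random.default_rng()
    remaining = np.arange(len(scores))  # indices into the super-batch
    selected = []
    chunk_size = batch_size // n_chunks
    for _ in range(n_chunks):
        # Softmax over the scores of the not-yet-selected examples
        # (shifted by the max for numerical stability).
        logits = scores[remaining] - scores[remaining].max()
        probs = np.exp(logits)
        probs /= probs.sum()
        picks = rng.choice(len(remaining), size=chunk_size,
                           replace=False, p=probs)
        selected.extend(remaining[picks])
        remaining = np.delete(remaining, picks)
    return np.array(selected)
```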
To further enhance efficiency, JEST introduces Flexi-JEST, which employs multi-resolution training. This method splits the training batch into high-resolution and low-resolution subsets. By using approximate models for scoring, Flexi-JEST reduces computational costs while maintaining performance. This method ensures that the overall training process remains efficient without compromising the quality of the data selected.
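
As a rough illustration of the multi-resolution split, the sketch below processes part of a batch at full resolution and the rest at a cheaper, lower resolution. The 50/50 split and the 128-pixel low resolution are illustrative choices, not the paper's exact settings.

```python
import torch
import torch.nn.functional as F

def split_multi_resolution(images, high_res_fraction=0.5, low_res=128):
    """Split an image batch (N, C, H, W) into a full-resolution
    subset and a downsampled subset for cheaper scoring/training."""
    n_high = int(len(images) * high_res_fraction)
    high = images[:n_high]  # kept at full resolution
    low = F.interpolate(images[n_high:], size=(low_res, low_res),
                        mode="bilinear", align_corners=False)
    return high, low
```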

Huge Speedup

In practical terms, JEST significantly accelerates training. By focusing on the most valuable data, it can achieve state-of-the-art results with up to 13× fewer iterations and 10× less computation. This efficiency makes JEST particularly valuable for large-scale multimodal learning tasks, such as training vision-language models.
JEST revolutionizes data curation by leveraging the power of joint example selection and pretrained reference models. By dynamically optimizing batch selection, it ensures that every bit of training effort is maximized. For researchers and practitioners using Weights & Biases to track and optimize their machine learning experiments, incorporating JEST can lead to faster, more efficient training and better overall performance. This innovative approach highlights the importance of smart data selection in pushing the boundaries of what's possible in AI.
Tags: ML News