Pre-training under infinite compute
We open-source the logs for the ~2000 runs we performed on Marin (covering epoching, parameter scaling, regularization, ensembling, distillation, and continued pre-training). We conduct the scaling analysis and downstream benchmark evaluations externally. We hope that releasing these logs allows the community to analyze our collection of pre-training runs for interesting insights. Note that many intuition-building runs were conducted at a much smaller scale before we confirmed our findings in Marin.
Abstract: Since compute grows much faster than internet text available for language model pre-training, we ask how we should approach pre-training under fixed data and no compute constraints. We first show that existing data-constrained approaches of increasing epoch count and parameter count eventually overfit, and we significantly improve upon such recipes by properly tuning regularization, finding that the optimal weight decay is $30\times$ larger than standard practice. Since our regularized recipe monotonically decreases loss following a simple power law in parameter count, we estimate its best possible performance via the asymptote of its scaling law rather than the performance at a fixed compute budget. We then identify that ensembling independently trained models achieves a significantly lower loss asymptote than the regularized recipe. Our best intervention combining epoching, regularization, parameter scaling, and ensemble scaling achieves an asymptote at 200M tokens that existing approaches need $5.17\times$ more data to match, and our data scaling laws predict that this improvement persists at higher token budgets. These data-efficiency gains are not solely due to increasing total parameter count, as we find that we can distill an ensemble into an $8\times$ smaller student model that retains $83\%$ of the ensembling benefit. Finally, our interventions designed for validation loss generalize to downstream benchmarks, achieving a $9\%$ improvement for pre-training evals and a $17.5\times$ data-efficiency improvement over continued pre-training on math mid-training data. Our results show that simple algorithmic improvements can enable significantly higher data-efficiency in a compute-rich future.
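The loss asymptotes mentioned in the abstract come from fitting a saturating power law in parameter count and reading off its limit. As a minimal sketch of this kind of fit (the data points below are hypothetical, and the exact functional form and fitting procedure used in the paper may differ):

```python
import numpy as np
from scipy.optimize import curve_fit

# Hypothetical (parameter count, validation loss) pairs from a sweep.
N = np.array([1.5e8, 3.0e8, 6.0e8, 1.2e9, 2.4e9])
L = np.array([3.95, 3.78, 3.66, 3.58, 3.53])

def saturating_power_law(n, E, A, alpha):
    # L(N) = E + A * N^(-alpha); E is the loss asymptote as N -> infinity.
    return E + A * n ** (-alpha)

(E, A, alpha), _ = curve_fit(saturating_power_law, N, L, p0=[3.0, 1e3, 0.3], maxfev=20000)
print(f"estimated asymptote E = {E:.3f}, exponent alpha = {alpha:.3f}")
```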
Contents:
Section 2: Existing approaches (epoching, parameter scaling)
Section 3: Regularized parameter scaling
Section 4: Ensembles
Section 6: Distillation
Section 7.2: Continued pre-training
Each pre-training batch has 0.25M tokens, so 800 batches = 200M tokens, etc.
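As a quick sanity check of that conversion (plain arithmetic, nothing taken from the training code):

```python
TOKENS_PER_BATCH = 0.25e6  # each pre-training batch has 0.25M tokens

def batches_to_tokens(num_batches: int) -> float:
    """Convert a run's batch count into its token budget."""
    return num_batches * TOKENS_PER_BATCH

print(batches_to_tokens(800) / 1e6)  # 200.0, i.e. 200M tokens
```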
Section 2: Existing approaches (epoching, parameter scaling)
Run set: Tuning epoch count and parameter count (64 runs)
Section 3: Regularized parameter scaling
Run set: Tuning weight decay (686 runs)
Section 4: Ensembles
Run set: Ensembles (338 runs)
Section 6: Distillation
Run sets:
- 300M baseline (1 run)
- Teacher run (1 run)
- 8-ensemble distill (15 runs)
- Self-distill (11 runs)
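To make the run groups above easier to interpret, here is a minimal sketch of logit-level distillation against a (possibly ensembled) teacher, where the student is trained to match the teacher's averaged token distribution. This illustrates the general technique only; the exact loss, temperature, and data mix used in these runs may differ.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits_list, temperature=1.0):
    """Distillation loss against an (optionally ensembled) teacher.

    student_logits: [batch, seq, vocab]
    teacher_logits_list: list of [batch, seq, vocab] tensors, one per ensemble member.
    """
    # Average teacher probabilities across ensemble members (assumption: simple mean).
    teacher_probs = torch.stack(
        [F.softmax(t / temperature, dim=-1) for t in teacher_logits_list]
    ).mean(dim=0)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # Cross-entropy to the teacher distribution, summed over the vocab, averaged over tokens.
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean()
```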
Section 7.2: Continued pre-training
NOTE: The accuracies reported here are inconsistent with lm-eval-harness. All numbers reported in the paper come from the lm-eval-harness implementation for ensembles and single models available at github.com/konwook/lm-eval-ensemble (a rough sketch of ensembled evaluation is given after the run list below).
Run sets:
- default 4B (2 runs)
- our single model (14 runs)
- default 73B (1 run)
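For the ensemble numbers above, one common way to ensemble language models in likelihood-based evals is to average per-token probabilities across members before scoring each answer continuation. The sketch below (assuming HuggingFace-style causal LMs that return `.logits`, with hypothetical helper arguments) illustrates that general idea; it is not necessarily the exact aggregation used in lm-eval-ensemble.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ensembled_choice_scores(models, input_ids, choice_spans):
    """Score answer choices with probabilities averaged across ensemble members.

    models: list of causal LMs returning logits of shape [batch, seq, vocab].
    input_ids: [batch, seq] token ids (prompt + continuation per choice).
    choice_spans: list of (start, end) index pairs marking continuation tokens.
    """
    # Average next-token probabilities across models (assumption: simple mean).
    probs = torch.stack(
        [F.softmax(m(input_ids).logits, dim=-1) for m in models]
    ).mean(dim=0)
    log_probs = probs.clamp_min(1e-12).log()

    scores = []
    for b, (start, end) in enumerate(choice_spans):
        # Log-prob of each continuation token under the averaged distribution;
        # the prediction for position t comes from position t - 1.
        token_lp = log_probs[b, start - 1 : end - 1].gather(
            -1, input_ids[b, start:end].unsqueeze(-1)
        ).squeeze(-1)
        scores.append(token_lp.sum().item())
    return scores  # the argmax over choices is the predicted answer
```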