Critical Batch Size Revisited: A Simple Empirical Approach to Large-Batch Language Model Training

William Merrill, Shane Arora, Dirk Groeneveld, Hannaneh Hajishirzi

TL;DR

This work addresses how to train large language models efficiently at large batch sizes. It introduces an empirical critical batch size $B^*$ that is measured directly through local branched training, avoiding the strong assumptions behind gradient-noise proxies. In the OLMo experiments, $B^*$ rises quickly in early training and then plateaus at around $4096$ documents (4096 tokens each), and the trend is similar for the 1B and 7B models, suggesting that CBS measurements from small runs can inform larger ones. The authors validate a batch size warmup strategy that increases the batch size as the CBS grows: on OLMo 1B, it reaches slightly better loss than the original training run with 43% fewer gradient steps. Overall, the paper provides a simple, practical framework for measuring and exploiting the CBS in large-scale language-model pretraining, with potential extensions to online CBS estimation and broader architectures.

Abstract

The right batch size is important when training language models at scale: a large batch size is necessary for fast training, but a batch size that is too large will harm token efficiency. To navigate this tradeoff, McCandlish et al. (2018) suggest that a critical batch size (CBS), below which training will not substantially degrade loss, can be estimated based on the gradient noise scale during training. While their method has been adopted in practice, e.g., when training GPT-3, strong assumptions are required to justify gradient noise as a proxy for the CBS, which makes it unclear whether their approach should be trusted in practice, limiting its applicability. In this paper, we introduce a simple, empirical approach to directly measure the CBS and show how the CBS evolves over training. Applying our approach to the OLMo models, we find that CBS is near 0 at initialization, increases rapidly at first, and then plateaus as training progresses. Furthermore, we find that this trend holds across different model sizes (1B and 7B), suggesting CBS from small training runs can inform larger-scale training runs. Our findings about how the CBS changes over training motivate batch size warmup as a natural way to reliably train language models at large batch size: start the batch size small and increase it as the CBS grows. To validate this claim, we use batch size warmup to train OLMo 1B to slightly better loss than the original training run with 43% fewer gradient steps. This shows how our framework can be applied to reliably train language models at larger batch sizes, increasing data parallelism without compromising performance.
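
The warmup idea in the abstract (start small, grow the batch size as the CBS grows) can be sketched as follows. This is a minimal illustration, not the authors' implementation: the `hypothetical_cbs` curve, its `plateau` and `scale` parameters, the doubling rule, and the starting batch size of 256 are all assumptions made for the example. Only the plateau near 4096 documents of 4096 tokens each comes from the paper.

```python
import math


def hypothetical_cbs(tokens_seen, plateau=4096.0, scale=2.0e11):
    """Made-up CBS curve (in documents): 0 at init, fast early growth, then a plateau."""
    return min(plateau, plateau * math.sqrt(tokens_seen / scale))


def next_batch_size(current, cbs_estimate, max_batch):
    """Double the batch size for as long as the doubled size stays under the CBS estimate."""
    batch = current
    while batch * 2 <= min(cbs_estimate, max_batch):
        batch *= 2
    return batch


def warmup_schedule(total_tokens, tokens_per_doc=4096, start_batch=256, max_batch=4096):
    """Simulate training and return (batch size changes, number of gradient steps)."""
    tokens_seen = 0
    batch = start_batch
    changes = [(0, start_batch)]  # (tokens seen when the change happens, new batch size)
    steps = 0
    while tokens_seen < total_tokens:
        new_batch = next_batch_size(batch, hypothetical_cbs(tokens_seen), max_batch)
        if new_batch != batch:
            changes.append((tokens_seen, new_batch))
            batch = new_batch
        tokens_seen += batch * tokens_per_doc  # one gradient step consumes one batch
        steps += 1
    return changes, steps


if __name__ == "__main__":
    changes, steps = warmup_schedule(total_tokens=3.0e11)
    for tokens, batch in changes:
        print(f"after {tokens:.3e} tokens: batch size {batch}")
    print("gradient steps:", steps)
```

Because every step at a larger batch size consumes more tokens, the simulated run needs fewer gradient steps than a run held at the starting batch size for the same token budget. The size of that saving depends entirely on the assumed CBS curve, so the numbers this sketch prints should not be compared with the paper's reported 43%.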

Paper Structure

This paper contains 29 sections, 3 theorems, 14 equations, 7 figures, 2 tables.

Key Result

Proposition 1

Let $f(t)$ be integrable with $f(0) = 0$ and define $R_2(B) = \int_0^T (f(t) - B)^2 \, \mathrm{d} t$. Then $R_2$ is minimized by $B^* = \frac{1}{T} \int_0^T f(t) \, \mathrm{d} t$.
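
The argument behind Proposition 1 can be sketched in a few lines. The sketch takes $R_2$ to be the squared-error objective $R_2(B) = \int_0^T (f(t) - B)^2 \, \mathrm{d}t$, which is an assumption here, chosen because it is consistent with the stated minimizer. It also assumes $f$ is square-integrable so that $R_2$ is finite:

$$
\begin{aligned}
\frac{\mathrm{d} R_2}{\mathrm{d} B} &= -2 \int_0^T (f(t) - B) \, \mathrm{d}t = -2 \int_0^T f(t) \, \mathrm{d}t + 2BT, \\
\frac{\mathrm{d} R_2}{\mathrm{d} B} = 0 &\iff B = \frac{1}{T} \int_0^T f(t) \, \mathrm{d}t, \\
\frac{\mathrm{d}^2 R_2}{\mathrm{d} B^2} &= 2T > 0.
\end{aligned}
$$

Since the second derivative is positive, the stationary point is the unique minimum. If $f(t)$ is read as the CBS at time $t$, the best constant batch size under this objective is the time-average of the CBS over training.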

Figures (7)

  • Figure 1: Smoothed final loss after branched training at particular checkpoints, with $B^*$ shown as the dotted red line. Each point represents the loss achieved by a specific branched training run after 2B tokens. Our method detects the point at which loss starts to increase, heuristically tolerating noise within $\epsilon = 0.01$. These plots show how this plays out for three particular checkpoints; see \ref{sec:cbs-measurement-details} for loss curves for all checkpoints.
  • Figure 2: CBS over training for OLMo 1B and 7B, measured in documents (4096 tokens per document). The qualitative trend is similar across both model sizes. The CBS starts near 0, grows rapidly but diminishingly, and plateaus around 4096.
  • Figure 3: Gradient noise scale for OLMo 1B and 7B computed via the estimator of McCandlish et al. (2018) with 95% confidence intervals; details in \ref{app:noise-scale}. The gradient noise scale underestimates the CBS (cf. \ref{fig:measure-cbs}) and the qualitative trend does not clearly match, especially for OLMo 7B.
  • Figure 4: Batch size schedule (left, top), learning rate schedule (left, bottom) and training loss (right) for the pretraining of an OLMo model with different batch size schedules. Training loss is smoothed by taking the moving average over the past 10B tokens.
  • Figure 5: All loss vs. batch size plots for OLMo 1B. Overall, the red line moves to the right over time, showing that the CBS increases.
  • ...and 2 more figures
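
The detection rule described in the Figure 1 caption (find where loss starts to increase, tolerating noise within $\epsilon = 0.01$) can be sketched as below. This is one plausible reading, not the authors' code: the function name `estimate_cbs`, the choice of the smallest-batch run as the baseline, and the example losses are all assumptions, and the smoothing step mentioned in the caption is taken as already applied to the inputs.

```python
def estimate_cbs(batch_sizes, losses, eps=0.01):
    """Return the largest batch size before loss first exceeds the baseline by more than eps.

    batch_sizes: batch sizes used for the branched runs from one checkpoint.
    losses: (already smoothed) final loss of each branched run, in the same order.
    eps: noise tolerance; 0.01 matches the value quoted in the Figure 1 caption.
    """
    pairs = sorted(zip(batch_sizes, losses))
    baseline = pairs[0][1]  # assumption: the smallest-batch run serves as the reference loss
    cbs = pairs[0][0]
    for batch, loss in pairs[1:]:
        if loss > baseline + eps:
            break  # loss has started to degrade, so stop at the previous batch size
        cbs = batch
    return cbs


if __name__ == "__main__":
    # Made-up losses for illustration only; these are not results from the paper.
    batch_sizes = [256, 512, 1024, 2048, 4096]
    losses = [3.000, 3.002, 2.998, 3.008, 3.050]
    print(estimate_cbs(batch_sizes, losses))
```

Repeating this at successive checkpoints gives one CBS estimate per checkpoint, which is the kind of curve Figure 2 reports.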

Theorems & Definitions (6)

  • Proposition 1
  • proof
  • Proposition 2: $B^*$ for power-law CBS
  • proof
  • Proposition 3: $B^*$ for log CBS
  • proof