
How Does Critical Batch Size Scale in Pre-training?

Hanlin Zhang, Depen Morwani, Nikhil Vyas, Jingfeng Wu, Difan Zou, Udaya Ghai, Dean Foster, Sham Kakade

TL;DR

This work formalizes the concept of critical batch size (CBS) in large-scale pre-training and then analyzes, empirically and theoretically, how CBS scales with model size and data size. By training autoregressive transformers from 85M to 1.2B parameters on the C4 corpus with extensive hyperparameter sweeps, the authors show that CBS is driven primarily by data size, with model size playing a weaker role, especially in compute-optimal, data-parallel training. They introduce a formal CBS definition, fit scaling laws (e.g., $B^* \approx 93.20\,N^{0.47}$ for model size $N$), and validate them with controlled experiments that decouple the effects of data and model size. Theoretical results from infinite-width limits and an infinite-dimensional least-squares regression analogue support the empirical findings. The work also highlights training strategies, such as exponential weight averaging (EWA), that allow large-scale pre-training to be studied without fixing the training duration in advance. Practically, the results inform compute planning: as the data budget grows, larger batches, and hence greater data parallelism, can be used without sacrificing efficiency.
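A power law of this form can be evaluated directly, and re-fitted from (model size, CBS) pairs with ordinary least squares in log-log space. The sketch below is illustrative only: the coefficients 93.20 and 0.47 are the ones quoted above, the batch-size units are assumed to match whatever units the paper used for its fit, and the helper names `predicted_cbs` and `fit_power_law` are hypothetical. The self-check regenerates points from the law itself rather than using any measured data.

```python
import math
from typing import List, Tuple

# Coefficients quoted in the TL;DR. The batch-size units are assumed to be
# those used in the paper's fit; they are not restated here.
COEF_A = 93.20
EXP_B = 0.47


def predicted_cbs(n_params: float, a: float = COEF_A, b: float = EXP_B) -> float:
    """Evaluate the power law B* = a * N**b (hypothetical helper)."""
    return a * n_params ** b


def fit_power_law(ns: List[float], cbs: List[float]) -> Tuple[float, float]:
    """Fit log B* = log a + b * log N by ordinary least squares; return (a, b)."""
    xs = [math.log(n) for n in ns]
    ys = [math.log(v) for v in cbs]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    b = sxy / sxx  # slope in log-log space = scaling exponent
    a = math.exp(mean_y - b * mean_x)  # intercept back-transformed to the prefactor
    return a, b


if __name__ == "__main__":
    # Endpoints of the studied range plus one illustrative midpoint (not a paper model size).
    sizes = [85e6, 3e8, 1.2e9]
    for n in sizes:
        print(f"N = {n:.3g}: predicted B* = {predicted_cbs(n):.4g}")
    # Self-check: points generated from the law should give back its coefficients.
    print(fit_power_law(sizes, [predicted_cbs(n) for n in sizes]))
```

Because the exponent is 0.47, doubling $N$ multiplies the predicted $B^*$ by $2^{0.47} \approx 1.39$; across the 85M to 1.2B range (about 14x in $N$), the predicted $B^*$ grows by roughly 3.5x.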

Abstract

Training large-scale models under given resources requires careful design of parallelism strategies. In particular, the efficiency notion of critical batch size (CBS), concerning the compromise between time and compute, marks the threshold beyond which greater data parallelism leads to diminishing returns. To operationalize it, we propose a measure of CBS and pre-train a series of auto-regressive language models, ranging from 85 million to 1.2 billion parameters, on the C4 dataset. Through extensive hyper-parameter sweeps and careful control of factors such as batch size, momentum, and learning rate along with its scheduling, we systematically investigate the impact of scale on CBS. Then we fit scaling laws with respect to model and data sizes to decouple their effects. Overall, our results demonstrate that CBS scales primarily with data size rather than model size, a finding we justify theoretically through the analysis of infinite-width limits of neural networks and infinite-dimensional least squares regression. Of independent interest, we highlight the importance of common hyper-parameter choices and strategies for studying large-scale pre-training beyond fixed training durations.
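To make the idea of a threshold concrete, the sketch below estimates a CBS from steps-to-target measurements: it extrapolates perfect linear scaling from the smallest batch size and returns the largest batch size whose step count stays within an overhead tolerance of that ideal. This is not necessarily the paper's Definition 1. The 20% tolerance (`overhead=0.2`) is an assumption, and the synthetic steps-to-target curve $S(B) = S_{\min}(1 + B_{\text{noise}}/B)$ is borrowed from McCandlish et al. (2018) as a stand-in for measured data, with hypothetical values $S_{\min}=1000$ and $B_{\text{noise}}=4096$.

```python
from typing import Dict


def steps_model(batch_size: float, s_min: float, b_noise: float) -> float:
    """Stand-in steps-to-target curve S(B) = S_min * (1 + B_noise / B) (McCandlish et al., 2018)."""
    return s_min * (1.0 + b_noise / batch_size)


def estimate_cbs(steps_by_batch: Dict[int, float], overhead: float = 0.2) -> int:
    """Largest batch size whose steps-to-target stay within (1 + overhead) of linear scaling.

    Perfect linear scaling is extrapolated from the smallest batch size B0:
    S_ideal(B) = S(B0) * B0 / B. The overhead tolerance is an assumed choice.
    """
    batches = sorted(steps_by_batch)
    b0 = batches[0]
    s0 = steps_by_batch[b0]
    cbs = b0
    for b in batches:
        ideal = s0 * b0 / b  # steps we would need if doubling B halved the steps
        if steps_by_batch[b] / ideal > 1.0 + overhead:
            break  # first batch size where the speedup falls too far below linear
        cbs = b
    return cbs


if __name__ == "__main__":
    grid = [2 ** k for k in range(6, 15)]  # batch sizes 64 ... 16384
    measured = {b: steps_model(b, s_min=1000.0, b_noise=4096.0) for b in grid}
    print(estimate_cbs(measured))  # 512
```

With these hypothetical values the ratio to ideal scaling is $(B + 4096)/(64 + 4096)$, which first exceeds 1.2 between $B=512$ and $B=1024$, so the sketch reports 512. Real measurements are noisy, so one would typically smooth or fit the curve before applying such a threshold.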

Paper Structure

This paper contains 27 sections, 5 theorems, 23 equations, 14 figures, 7 tables.

Key Result

Theorem 1

In the infinite-width regime (yang21), the training dynamics and performance of the networks become effectively independent of model size. Consequently, the critical batch size remains nearly invariant when model size is scaled up beyond this point, indicating that larger models do not require proportionally larger batch sizes.

Figures (14)

  • Figure 1: Optimization efficiency and scaling of critical batch size in the Chinchilla (left) and controlled (middle, right) settings. To study how CBS varies across model sizes, we track the relative number of steps required to reach a given target validation loss. In the Chinchilla setting (left), we keep the data-to-model size ratio $D/N=C_{\text{Chin}}$ constant and observe that CBS increases with scale. However, when controlling for either model size (middle) or data size (right), the growth in CBS becomes mostly dependent on data size rather than model size (see the scaling-laws section).
  • Figure 2: Comparing and accounting for training dynamics. Throughout, we adopt a constant learning rate with exponential weight averaging (Constant+EWA), since it performs best at large batch sizes and avoids having to fix the training duration in advance to reach a target loss.
  • Figure 3: Scaling up width and scaling up depth yield similar efficiency gains for compute-optimal training.
  • Figure 4: Ablation results on context length using 151M models.
  • Figure 5: Illustration of critical batch size, where $B^*=2^{11.87}$ and context length is 512 by default.
  • ...and 9 more figures
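The Constant+EWA strategy keeps the learning rate fixed and maintains an exponential moving average of the weights, $\bar\theta_t = \beta\,\bar\theta_{t-1} + (1-\beta)\,\theta_t$, which can be evaluated at any step without committing to a training horizon. The toy sketch below runs noisy SGD on a quadratic to show the averaged weights sitting closer to the optimum than the raw iterate. The quadratic objective, $\beta = 0.99$, the learning rate, and the noise level are assumptions for illustration, not the paper's settings, and `train_constant_lr_with_ewa` is a hypothetical helper.

```python
import random
from typing import List, Tuple


def noisy_grad(theta: List[float], rng: random.Random, noise_std: float) -> List[float]:
    """Stochastic gradient of 0.5 * ||theta||^2 with additive Gaussian noise (toy objective)."""
    return [t + noise_std * rng.gauss(0.0, 1.0) for t in theta]


def loss(theta: List[float]) -> float:
    """Toy quadratic loss with its minimum at theta = 0."""
    return 0.5 * sum(t * t for t in theta)


def train_constant_lr_with_ewa(steps: int = 5000, lr: float = 0.1, beta: float = 0.99,
                               noise_std: float = 1.0, dim: int = 8, seed: int = 0,
                               tail: int = 1000) -> Tuple[float, float]:
    """Return the mean loss of the raw and EWA weights over the last `tail` steps."""
    rng = random.Random(seed)
    theta = [5.0] * dim
    ewa = list(theta)
    raw_tail: List[float] = []
    ewa_tail: List[float] = []
    for step in range(steps):
        g = noisy_grad(theta, rng, noise_std)
        theta = [t - lr * gi for t, gi in zip(theta, g)]  # constant learning rate, no decay
        ewa = [beta * e + (1.0 - beta) * t for e, t in zip(ewa, theta)]  # weight averaging
        if step >= steps - tail:
            raw_tail.append(loss(theta))
            ewa_tail.append(loss(ewa))
    return sum(raw_tail) / tail, sum(ewa_tail) / tail


if __name__ == "__main__":
    raw, averaged = train_constant_lr_with_ewa()
    print(f"raw iterate loss: {raw:.4f}, EWA loss: {averaged:.4f}")
```

With a constant learning rate the raw iterate keeps bouncing around the optimum at a noise floor set by the step size, while the average smooths out that noise, which plays a role similar to learning-rate decay without fixing when training ends.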

Theorems & Definitions (9)

  • Theorem 1: Informal version of [thm:mup]
  • Corollary 1: Informal version of [thm:cbs-linear-regression]
  • Definition 1
  • Theorem 2
  • Proof of [thm:mup]
  • Theorem 3
  • Corollary 2
  • Proof of [thm:mini-batch]
  • Proof of [thm:cbs-linear-regression]