Adaptive Batch Size Schedules for Distributed Training of Language Models with Data and Model Parallelism

Tim Tsz-Kit Lau, Weijian Li, Chenwei Xu, Han Liu, Mladen Kolar

TL;DR

The paper tackles the challenge of selecting adaptive batch sizes for pretraining large language models under distributed data and model parallelism. It develops a norm-test-based adaptive batching framework (DDP-Norm and FSDP-Norm) that automatically scales the batch size in response to gradient-noise estimates, with convergence guarantees for Adam. Empirically, it demonstrates improvements over constant batch sizes and heuristic warmup schedules for MicroLlama, TinyLlama, and OpenLlama on the C4 dataset, while enabling training on practical hardware using PyTorch FSDP. The work makes both theoretical and practical contributions, offering a principled path to more efficient and stable large-scale pretraining and highlighting avenues for extending to multi-dimensional parallelism and downstream tasks.
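
As a rough illustration of how a norm test can drive the batch size, the sketch below implements the generic approximate norm test on per-sample gradients with NumPy. It is not the paper's DDP-Norm or FSDP-Norm implementation: the function name `norm_test_batch_size`, its parameters, and the use of per-sample (rather than per-worker) gradients are assumptions made for illustration only.

```python
import math

import numpy as np


def norm_test_batch_size(per_sample_grads, eta, max_batch_size=None):
    """Approximate norm test (illustrative sketch, not the paper's code).

    per_sample_grads: array-like of shape (b, d), one flattened gradient per sample.
    eta: noise tolerance in (0, 1); smaller values demand less gradient noise.
    Returns (passed, suggested_batch_size).
    """
    g = np.asarray(per_sample_grads, dtype=float)
    b = g.shape[0]
    if b < 2:
        raise ValueError("need at least two per-sample gradients")

    batch_grad = g.mean(axis=0)
    # Unbiased estimate of the trace of the per-sample gradient covariance.
    var_trace = float(np.sum((g - batch_grad) ** 2) / (b - 1))
    grad_sq_norm = float(np.sum(batch_grad ** 2))

    # Test: estimated variance of the batch gradient vs. eta^2 * ||batch_grad||^2.
    threshold = eta ** 2 * grad_sq_norm
    passed = var_trace / b <= threshold
    if passed or threshold == 0.0:
        # Keep the current size (a zero gradient gives no usable target size).
        return passed, b

    # Smallest batch size for which the estimated variance would pass the test.
    suggested = max(b, math.ceil(var_trace / threshold))
    if max_batch_size is not None:
        suggested = min(suggested, max_batch_size)
    return False, suggested
```

For example, with the two per-sample gradients `[[2.0, 0.0], [0.0, 0.0]]` and `eta=0.5`, the estimated variance is too large for a batch of 2 and the function suggests a batch size of 8. In a distributed setting the per-sample gradients would typically be replaced by per-worker gradients, which is an adaptation this sketch does not attempt.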

Abstract

An appropriate choice of batch sizes in large-scale model training is crucial, yet it involves an intrinsic yet inevitable dilemma: large-batch training improves training efficiency in terms of memory utilization, while generalization performance often deteriorates due to small amounts of gradient noise. Despite this dilemma, the common practice of choosing batch sizes in language model training often prioritizes training efficiency -- employing either constant large sizes with data parallelism or implementing batch size warmup schedules. However, such batch size schedule designs remain heuristic and often fail to adapt to training dynamics, presenting the challenge of designing adaptive batch size schedules. Given the abundance of available datasets and the data-hungry nature of language models, data parallelism has become an indispensable distributed training paradigm, enabling the use of larger batch sizes for gradient computation. However, vanilla data parallelism requires replicas of model parameters, gradients, and optimizer states at each worker, which prohibits training larger models with billions of parameters. To optimize memory usage, more advanced parallelism strategies must be employed. In this work, we propose general-purpose and theoretically principled adaptive batch size schedules compatible with data parallelism and model parallelism. We develop a practical implementation with PyTorch Fully Sharded Data Parallel, facilitating the pretraining of language models of different sizes. We empirically demonstrate that our proposed approaches outperform constant batch sizes and heuristic batch size warmup schedules in the pretraining of models in the Llama 2 family, with particular focus on smaller models with up to 3 billion parameters. We also establish theoretical convergence guarantees for such adaptive batch size schedules with Adam for general smooth nonconvex objectives.

Paper Structure

This paper contains 32 sections, 7 theorems, 21 equations, 4 figures, 7 tables, 1 algorithm.

Key Result

Proposition 1

The coordinate-wise (exact variance) norm test with constant $\eta\in(0,1)$ ensures that, for every iteration $k\in\llbracket K\rrbracket$, the coordinate-wise batch gradient $\partial_i \mathscr{L}_{\mathcal{B}_k}(w_k)$ satisfies the following coordinate-wise expected strong growth (E-SG) condition
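
The inequality itself is not shown here. As a hedged sketch of why a norm test leads to a strong-growth-type bound, assume the batch gradient is an unbiased estimate of the full gradient and that the coordinate-wise norm test bounds its variance as in the first line below. The symbols $\mathbb{E}_k$ (expectation conditional on $w_k$) and $\mathscr{L}$ (full objective) are notational assumptions, and the resulting bound may differ from the exact statement in the paper.

```latex
% Assumed form of the coordinate-wise (exact variance) norm test:
\begin{align*}
\mathbb{E}_k\big[(\partial_i \mathscr{L}_{\mathcal{B}_k}(w_k) - \partial_i \mathscr{L}(w_k))^2\big]
  &\le \eta^2 \, (\partial_i \mathscr{L}(w_k))^2 . \\
% Bias-variance decomposition, using E_k[\partial_i L_{B_k}(w_k)] = \partial_i L(w_k):
\mathbb{E}_k\big[(\partial_i \mathscr{L}_{\mathcal{B}_k}(w_k))^2\big]
  &= \mathbb{E}_k\big[(\partial_i \mathscr{L}_{\mathcal{B}_k}(w_k) - \partial_i \mathscr{L}(w_k))^2\big]
     + (\partial_i \mathscr{L}(w_k))^2 \\
  &\le (1 + \eta^2) \, (\partial_i \mathscr{L}(w_k))^2 .
\end{align*}
```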

Figures (4)

  • Figure 1: Generalization gap in transformer pretraining. Various curves represent distinct batch sizes.
  • Figure 2: Training loss, validation loss and batch size schedule for MicroLlama 300M
  • Figure 3: Training loss, validation loss and batch size schedule for TinyLlama 1.1B
  • Figure 4: Training loss, validation loss and batch size schedule for OpenLlama 3B

Theorems & Definitions (12)

  • Proposition 1
  • Theorem 1
  • Remark B.1
  • Lemma B.1
  • Lemma B.2
  • Proof Sketch
  • Theorem B.1: Formal version of Theorem 1
  • Lemma B.3
  • Proof Sketch
  • Lemma B.4
  • ...and 2 more