Stochastic Rounding for LLM Training: Theory and Practice

Kaan Ozkara, Tao Yu, Youngsuk Park

TL;DR

This work studies stochastic rounding (SR) as a low-precision training technique for very large language models and analyzes its theoretical and practical implications when used with the AdamW optimizer. It introduces BF16-AdamW-SR, extends SR to distributed settings with shared randomness, and demonstrates that BF16+SR can surpass traditional BF16/FP32 mixed-precision training in both perplexity and efficiency, achieving up to $1.54\times$ higher throughput and up to $30\%$ memory savings on models up to $6.7$B parameters. The authors establish that SR induces implicit regularization in the loss and provide convergence bounds for Adam with SR that can subsume quantization error under appropriate hyper-parameter choices, particularly high learning rates. Empirically, BF16+SR matches or improves perplexity while delivering substantial speedups and memory reductions, and the approach generalizes to very large model scales with minimal overhead, offering a practical path toward efficient distributed LLM training.

Abstract

As the parameters of Large Language Models (LLMs) have scaled to hundreds of billions, the demand for efficient training methods -- balancing faster computation and reduced memory usage without sacrificing accuracy -- has become more critical than ever. In recent years, various mixed precision strategies, which use different precision levels for different optimization components, have been proposed to increase training speed with minimal accuracy degradation. However, these strategies often require manual adjustments and lack theoretical justification. In this work, we leverage stochastic rounding (SR) to address the numerical errors of training with low-precision representations. We provide theoretical analyses of implicit regularization and convergence under the Adam optimizer when SR is used. With the insights from these analyses, we extend the previous BF16 + SR strategy to distributed settings, enhancing stability and performance for large-scale training. Empirical results from pre-training models with up to 6.7B parameters demonstrate, for the first time, that our BF16 with SR strategy outperforms (BF16, FP32) mixed precision strategies, achieving better validation perplexity, up to $1.54\times$ higher throughput, and $30\%$ less memory usage.
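To make the rounding scheme concrete, the following is a minimal NumPy sketch of stochastic rounding from FP32 to the BF16 grid. It is not the authors' BF16-AdamW-SR implementation; it uses a common bit-level formulation (add uniform random noise to the 16 bits that BF16 discards, then truncate), and the function name `stochastic_round_to_bf16` is hypothetical. The sketch assumes finite inputs and returns FP32 values whose low 16 bits are zero, so each one is exactly representable in BF16.

```python
import numpy as np


def stochastic_round_to_bf16(x, rng):
    """Stochastically round float32 values to the bfloat16 grid.

    Hypothetical helper for illustration. The result is float32 with the
    low 16 bits cleared. Assumes finite inputs (NaN/inf are not handled).
    """
    x = np.asarray(x, dtype=np.float32)
    bits = x.view(np.uint32)
    # Uniform random integer in [0, 2^16), one per element.
    noise = rng.integers(0, 1 << 16, size=bits.shape, dtype=np.uint32)
    # A carry into bit 16 occurs with probability (discarded bits) / 2^16,
    # which rounds the magnitude up; otherwise the value is truncated.
    rounded = (bits + noise) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


rng = np.random.default_rng(0)
# 1.001 lies between the BF16 neighbours 1.0 and 1.0078125 (spacing 2^-7).
x = np.full(100_000, 1.001, dtype=np.float32)
y = stochastic_round_to_bf16(x, rng)
print(np.unique(y))               # the two neighbouring BF16 values
print(y.mean(dtype=np.float64))   # close to 1.001 on average
```

Round-to-nearest would map every copy of 1.001 to 1.0 and lose the small increment entirely. With stochastic rounding the increment survives in expectation, which is the property that makes small parameter updates in BF16 viable.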

Paper Structure

This paper contains 24 sections, 12 theorems, 53 equations, 8 figures, 8 tables, 1 algorithm.

Key Result

Proposition 1

Assume that the gradient estimated at dimension $i\in\{1,\ldots,d\}$ and step $t\in\{1,\ldots,T\}$ is composed of two Bernoulli parts, $g[i,t] = g^{(1)}[i] + g^{(2)}[i,t]$, where $g^{(1)}[i] \sim \rho \times \text{Bernoulli}(\frac{1}{2})$ and $g^{(2)}[i,t] \sim \text{Bernoulli}(\frac{1}{2})$ are independent. ...

Figures (8)

  • Figure 2: Depiction of updates when full precision and quantized stochastic rounding updates are used in the toy example.
  • Figure 3: Position of $x$ vs. number of iterations for FP32 updates (left) and BF16+SR updates (right); each curve shows the position over iterations for an individual run. With decaying updates, SR can take longer to converge: here FP32 converges in 4600 steps, while SR converges in 7700 steps on average (over 100 runs). An individual run with SR can take as many as 40,000 iterations to converge.
  • Figure 4: Training and validation loss curves comparing FP32 training and BF16+SR strategy while training GPT-2 (350M).
  • Figure 5: Validation (left) and training losses (middle--zoomed out, right--zoomed in) for training GPT-Neo (6.7B).
  • Figure 6: Validation (left) and training losses (right) for training GPT-Neo (2.7B).
  • ...and 3 more figures

Theorems & Definitions (18)

  • Proposition 1
  • Lemma 1: Effective gradient perturbation
  • Theorem 1
  • Theorem 2: No momentum, $\beta_1=0$
  • Corollary 1: Comparison to full precision Adam
  • Theorem 3
  • Proof: Proof of Lemma \ref{lemma:xi}
  • Theorem: Restatement of \ref{thm:implicit reg}
  • Proof
  • Theorem: Restatement of \ref{thm:convergence}
  • ...and 8 more