
Scaling FP8 training to trillion-token LLMs

Maxim Fishman, Brian Chmiel, Ron Banner, Daniel Soudry

TL;DR

The paper demonstrates FP8 training at an unprecedented scale of 2 trillion tokens and uncovers instabilities caused by outlier amplification in the SwiGLU activation. It analyzes the weight-alignment dynamics within SwiGLU that drive this amplification and introduces Smooth-SwiGLU, which stabilizes FP8 training without altering the function the model computes. It also quantizes both Adam moments to FP8. By combining Smooth-SwiGLU with FP8 optimizer moments, the authors train a 7B model on 256 Intel Gaudi2 accelerators with BF16-equivalent performance and up to ~34% higher throughput. The work supports reproducibility with a public codebase and discusses the environmental implications of reduced-precision training. Overall, the approach offers a viable and more efficient path to large-scale FP8 LLM training without sacrificing downstream quality.
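
Why weight alignment amplifies outliers can be seen from a short worked derivation. It assumes the standard SwiGLU definition with Swish(z) = z·σ(z), and the idealized case where an outlier channel's $\mathbf{w}_2$ is exactly proportional to its $\mathbf{w}_1$ (the paper reports high, not perfect, late-stage correlation in Figure 2c), so this is an illustration rather than the paper's own analysis:

$$
\mathrm{SwiGLU}_{\mathbf{w}_1,\mathbf{w}_2}(\mathbf{x}) = \mathrm{Swish}(\mathbf{w}_1^\top\mathbf{x})\,(\mathbf{w}_2^\top\mathbf{x}), \qquad \mathrm{Swish}(z) = z\,\sigma(z).
$$

Setting $\mathbf{w}_2 = c\,\mathbf{w}_1$ and $z = \mathbf{w}_1^\top\mathbf{x}$:

$$
\mathrm{SwiGLU}(\mathbf{x}) = c\,z^2\,\sigma(z) \approx c\,z^2 \quad \text{for } z \gg 0, \qquad \mathrm{SwiGLU}(\mathbf{x}) \to 0 \quad \text{for } z \ll 0.
$$

For large positive pre-activations the output therefore grows quadratically in $z$ instead of linearly: at $z = 10$ it is already about $100c$, and doubling $z$ to $20$ roughly quadruples it to about $400c$.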

Abstract

We train, for the first time, large language models using FP8 precision on datasets of up to 2 trillion tokens, a 20-fold increase over previous limits. Through these extended training runs, we uncover critical instabilities in FP8 training that were not observable in earlier works with shorter durations. We trace these instabilities to outlier amplification by the SwiGLU activation function. Interestingly, we show, both analytically and empirically, that this amplification happens only over prolonged training periods, and link it to a SwiGLU weight alignment process. To address this newly identified issue, we introduce Smooth-SwiGLU, a novel modification that ensures stable FP8 training without altering function behavior. We also demonstrate, for the first time, FP8 quantization of both Adam optimizer moments. Combining these innovations, we successfully train a 7B parameter model using FP8 precision on 256 Intel Gaudi2 accelerators, achieving results on par with the BF16 baseline while delivering up to a $\sim 34\%$ throughput improvement. A reference implementation is available at https://github.com/Anonymous1252022/Megatron-DeepSpeed.
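
Figure 4 and the abstract describe Smooth-SwiGLU only at the block level. The sketch below illustrates the general idea under several assumptions of ours: FP8 E4M3 rounding is simulated in pure Python, the input to $\mathbf{w}_3$ is quantized with a single per-tensor max-abs scale, and the per-channel factor is $s_c = 1 / \max_t |h_{t,c}|$. All function names are hypothetical, and this is not the code in the reference repository. Because each $s_c^{-1}$ is folded into the matching row of $\mathbf{w}_3$, both paths compute the same function in exact arithmetic; only the FP8 rounding error differs.

```python
import math

E4M3_MAX = 448.0  # largest finite E4M3 value (the "fn" variant without infinities)


def quantize_e4m3(v: float) -> float:
    """Simulated round-to-nearest into FP8 E4M3: 3 mantissa bits, min normal exponent -6."""
    if v == 0.0:
        return 0.0
    a = min(abs(v), E4M3_MAX)            # saturate rather than overflow
    _, e = math.frexp(a)                  # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)                  # exponent of the leading bit; below -6 is subnormal
    step = 2.0 ** (exp - 3)               # gap between neighbouring E4M3 values at this exponent
    return math.copysign(min(round(a / step) * step, E4M3_MAX), v)


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def swiglu(x, w1, w2):
    """SwiGLU for one token: channel c is Swish(w1[c] . x) * (w2[c] . x)."""
    out = []
    for a_row, b_row in zip(w1, w2):
        z1 = sum(xi * wi for xi, wi in zip(x, a_row))
        z2 = sum(xi * wi for xi, wi in zip(x, b_row))
        out.append(z1 * sigmoid(z1) * z2)
    return out


def quantize_tensor(H):
    """One per-tensor max-abs scale, then simulated E4M3 rounding (returned dequantized)."""
    amax = max(abs(v) for row in H for v in row) or 1.0
    scale = amax / E4M3_MAX
    return [[quantize_e4m3(v / scale) * scale for v in row] for row in H]


def matmul(H, W):
    """(tokens x channels) @ (channels x outputs) with plain lists."""
    return [[sum(h * W[c][o] for c, h in enumerate(row)) for o in range(len(W[0]))]
            for row in H]


def w3_standard(H, w3):
    # Baseline: quantize the SwiGLU output directly, so one outlier channel sets the scale.
    return matmul(quantize_tensor(H), w3)


def w3_smooth(H, w3):
    # Smooth-SwiGLU-style (hypothetical): scale channel c by s_c before quantization,
    # then fold s_c^{-1} into row c of w3 so the overall function is unchanged.
    n_ch = len(H[0])
    s = []
    for c in range(n_ch):
        m = max(abs(row[c]) for row in H)
        s.append(1.0 / m if m > 0 else 1.0)
    scaled = [[v * s[c] for c, v in enumerate(row)] for row in H]
    w3_folded = [[w / s[c] for w in w3[c]] for c in range(n_ch)]
    return matmul(quantize_tensor(scaled), w3_folded)


# In a real layer, H would be [swiglu(x, w1, w2) for x in tokens]; here it is synthetic.
H = [[1000.0, 0.01], [500.0, 0.02]]   # channel 0 is an outlier channel
w3 = [[0.0], [1.0]]                    # the output reads only channel 1
print(w3_standard(H, w3))  # about 0.0087 and 0.0218: channel 1 loses precision
print(w3_smooth(H, w3))    # about 0.01 and 0.02
```

With a shared scale, the outlier channel pushes the small channel into the subnormal range of E4M3, where only a few quantization steps remain; per-channel smoothing gives every channel the full range before the shared scale is applied.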

Paper Structure (28 sections, 10 equations, 11 figures, 5 tables)

Figures (11)

  • Figure 1: Comparison of activation maximum values across different layers during 50 iterations of training: (a) At the beginning of training, showing stable maximum values. (b) After 200B tokens of training, revealing sporadic but significant outliers (notice the change in the z-axis scale).
  • Figure 2: (a) Training loss of Llama2-7B with BF16 and FP8 precision; the FP8 loss diverges significantly after $\sim$200B tokens. (b) Dynamics of the $\mathbf{w}_1$ and $\mathbf{w}_2$ norms and their correlation during training, for a specific channel that generates outliers. A sharp increase in both correlation and norm occurs at the same point where the loss in (a) begins to degrade. (c) Scatter plot of the elements of an outlier channel in $\mathbf{w}_1$ and $\mathbf{w}_2$ at an early training stage (8B tokens) and a late training stage (330B tokens), showing minimal correlation at the start of training and high correlation later. (d) Histogram of an outlier channel of $\mathbf{w}_1$ at an early (8B tokens) and a late (330B tokens) training stage.
  • Figure 3: Training loss of Llama2 in FP8 with and without quantization of the SwiGLU output. The divergence of standard FP8 is caused by amplification in the SwiGLU output (the input to $\mathbf{w}_3$).
  • Figure 4: A standard quantized MLP block containing the original quantized SwiGLU (a) and the proposed quantized Smooth-SwiGLU (b), which improves stability under FP8 training. Here, $s$ is the scaling factor, $\hat{\mathbf{w}}_{1}$, $\hat{\mathbf{w}}_{2}$ and $\hat{\mathbf{w}}_{3}$ are the quantized weights, and $Q$ is the quantization function.
  • Figure 5: All combinations of standard FP8 formats for quantizing the Adam moments in Llama2 100M. The only combination that converges to the baseline uses E4M3 for the first moment and E5M2 for the second moment.
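
Figure 5 reports that the only standard-FP8 combination that matches the baseline stores the first Adam moment in E4M3 and the second in E5M2. A minimal sketch of what that means, assuming simulated FP8 rounding in pure Python and one max-abs scale per moment tensor (both assumptions; the reference implementation's scaling scheme may differ). `round_fp8`, `quantize` and `FP8MomentAdam` are hypothetical names, not an existing optimizer API.

```python
import math

# (mantissa bits, min normal exponent, max finite value) for the two standard FP8 formats
E4M3 = (3, -6, 448.0)
E5M2 = (2, -14, 57344.0)


def round_fp8(v: float, fmt) -> float:
    """Round v to the nearest value representable in the given FP8 format (saturating)."""
    mant_bits, min_exp, max_val = fmt
    if v == 0.0:
        return 0.0
    a = min(abs(v), max_val)
    _, e = math.frexp(a)                  # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, min_exp)             # clamp to the subnormal exponent
    step = 2.0 ** (exp - mant_bits)
    return math.copysign(min(round(a / step) * step, max_val), v)


def quantize(values, fmt):
    """Per-tensor max-abs scaling, then FP8 rounding; returns (codes, scale)."""
    amax = max(abs(v) for v in values) or 1.0
    scale = amax / fmt[2]
    return [round_fp8(v / scale, fmt) for v in values], scale


def dequantize(codes, scale):
    return [c * scale for c in codes]


class FP8MomentAdam:
    """Hypothetical Adam whose moments live as FP8 codes plus one scale per tensor."""

    def __init__(self, n, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                 m_fmt=E4M3, v_fmt=E5M2):
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.m_fmt, self.v_fmt = m_fmt, v_fmt
        self.m = ([0.0] * n, 1.0)
        self.v = ([0.0] * n, 1.0)
        self.t = 0

    def step(self, params, grads):
        self.t += 1
        m = dequantize(*self.m)
        v = dequantize(*self.v)
        m = [self.b1 * mi + (1 - self.b1) * g for mi, g in zip(m, grads)]
        v = [self.b2 * vi + (1 - self.b2) * g * g for vi, g in zip(v, grads)]
        # Store the updated moments in FP8 and use the dequantized values for the
        # update, so the step sees exactly what an FP8-resident state would hold.
        self.m = quantize(m, self.m_fmt)
        self.v = quantize(v, self.v_fmt)
        m = dequantize(*self.m)
        v = dequantize(*self.v)
        c1 = 1 - self.b1 ** self.t        # bias corrections
        c2 = 1 - self.b2 ** self.t
        return [p - self.lr * (mi / c1) / (math.sqrt(vi / c2) + self.eps)
                for p, mi, vi in zip(params, m, v)]


opt = FP8MomentAdam(n=2, lr=1e-3)
params = opt.step([0.0, 0.0], [1.0, -0.5])  # first step: each weight moves about lr against its gradient
print(params)
```

E5M2 trades a mantissa bit for exponent range. The second moment sits under a square root in the denominator, so an element that underflows to zero turns its update into roughly m/eps; a wider range makes that less likely, which is one plausible reading of why E5M2 suits the second moment.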