Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention

Haiquan Qiu, Quanming Yao

TL;DR

The paper addresses why low-precision training of transformers with flash attention can catastrophically fail, revealing a causal mechanism built from two interacting factors: emergent similar low-rank gradient updates and biased BF16 rounding errors that accumulate over training. By tracing the instability to the backward path term ${\boldsymbol{\delta}} = \mathrm{rowsum}(d{\mathbf{O}} \circ {\mathbf{O}})$ and the BF16 computation of ${\mathbf{O}}$, the authors demonstrate that biased rounding in the forward path induces a persistent positive δ-difference that biases weight updates through a common low-rank structure ${\mathbf{R}}$. They validate this picture with a minimal, targeted fix: modifying the softmax normalization to prevent any ${\bar{\mathbf{P}}}$ entry from reaching 1, which eliminates the biased rounding and stabilizes training without changing the backward path. The practical impact is a simple, hardware-agnostic adjustment that enables robust, memory-efficient low-precision transformer training, with code available at the linked repository. Overall, the work provides a principled route to diagnose and remediate numerical instabilities in efficient transformer training at scale.
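As a rough illustration of the quantity named above, the following sketch computes ${\boldsymbol{\delta}} = \mathrm{rowsum}(d{\mathbf{O}} \circ {\mathbf{O}})$ twice: once with ${\mathbf{O}}$ kept in FP32 and once with ${\mathbf{O}}$ rounded to BF16. It is not the authors' code. The helper names `to_bf16` and `delta`, the random inputs, and the NumPy bit-level emulation of BF16 round-to-nearest-even are all assumptions made for this example, and random data is not expected to reproduce the systematic bias the paper reports.

```python
import numpy as np

def to_bf16(x):
    """Emulate BF16 rounding (round-to-nearest-even) on float32 data.

    Hypothetical helper: BF16 keeps the top 16 bits of a float32, so we
    round the low 16 bits away. NaN/inf are not handled specially.
    """
    x = np.asarray(x, dtype=np.float32)
    bits = x.view(np.uint32)
    # Least significant bit of the part we keep, used to break ties to even.
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

def delta(dO, O):
    """delta = rowsum(dO * O): elementwise product summed over the feature axis."""
    return np.sum(dO * O, axis=-1)

# Toy inputs (assumed shapes: 4 tokens, 8 features); not data from the paper.
rng = np.random.default_rng(0)
O = rng.random((4, 8), dtype=np.float32)
dO = rng.standard_normal((4, 8), dtype=np.float32)

delta_fp32 = delta(dO, O)           # O kept in FP32
delta_bf16 = delta(dO, to_bf16(O))  # O rounded to BF16 before the product

# Per-row gap between the two versions of delta.
print(delta_bf16 - delta_fp32)
```

The printed gap is the per-row difference in ${\boldsymbol{\delta}}$ caused only by rounding ${\mathbf{O}}$. With unstructured random inputs these gaps should look like small noise of mixed sign; the paper's claim concerns the case where the rounding error in ${\mathbf{O}}$ is biased, so that the gap keeps the same sign across steps.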

Abstract

The pursuit of computational efficiency has driven the adoption of low-precision formats for training transformer models. However, this progress is often hindered by notorious training instabilities. This paper provides the first mechanistic explanation for a long-standing and unresolved failure case where training with flash attention in low-precision settings leads to catastrophic loss explosion. Our in-depth analysis reveals that the failure is not a random artifact but caused by two intertwined phenomena: the emergence of similar low-rank representations within the attention mechanism and the compounding effect of biased rounding errors inherent in low-precision arithmetic. We demonstrate how these factors create a vicious cycle of error accumulation that corrupts weight updates, ultimately derailing the training dynamics. To validate our findings, we introduce a minimal modification to the flash attention that mitigates the bias in rounding errors. This simple change stabilizes the training process, confirming our analysis and offering a practical solution to this persistent problem. Code is available at https://github.com/ucker/why-low-precision-training-fails.

Paper Structure

This paper contains 36 sections, 12 equations, 10 figures, and 3 algorithms.

Figures (10)

  • Figure 1: Analysis in different sections. Our paper traces the causal chain of training failure (blue box) in reverse to identify the root causes.
  • Figure 2: The failure case using BF16 and flash attention results in a sudden loss explosion, while the stable configuration converges.
  • Figure 3: ${\mathbf{W}}^Q$ of attention head 8 has the largest spectral norm. Subsequent analysis focuses on this head.
  • Figure 4: ${\mathbf{P}} {\mathbf{K}}$, ${\mathbf{X}}$, and $({\mathbf{P}} {\mathbf{K}})[T]^\top {\mathbf{X}}[T]$ at different batch indices and training steps. (c) and (f) show that $({\mathbf{P}} {\mathbf{K}})[T]^\top {\mathbf{X}}[T]$ for different tokens and training steps have some similar columns in input features 546 and 678.
  • Figure 5: Analysis of ${\boldsymbol{\delta}}=\mathrm{rowsum}(d{\mathbf{O}} \circ {\mathbf{O}})$.
  • ...and 5 more figures

Theorems & Definitions (4)

  • Claim 1
  • Claim 2
  • Remark 1
  • Claim 3