Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
Haiquan Qiu, Quanming Yao
TL;DR
The paper addresses why low-precision training of transformers with flash attention can catastrophically fail, revealing a causal mechanism built from two interacting factors: the emergence of similar low-rank gradient updates and biased BF16 rounding errors that accumulate over training. By tracing the instability to the backward-path term $\boldsymbol{\delta} = \mathrm{rowsum}(d\mathbf{O} \circ \mathbf{O})$ and to the BF16 computation of $\mathbf{O}$, the authors show that biased rounding in the forward path induces a persistent positive difference in $\boldsymbol{\delta}$, which biases weight updates through a common low-rank structure $\mathbf{R}$. They validate this picture with a minimal, targeted fix: modifying the softmax normalization so that no entry of $\bar{\mathbf{P}}$ reaches 1, which mitigates the biased rounding and stabilizes training without changing the backward path. The practical result is a simple adjustment that enables robust, memory-efficient low-precision transformer training; code is available at the linked repository. Overall, the work provides a principled route to diagnosing and remediating numerical instabilities in efficient transformer training at scale.
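To make the quantities above concrete, here is a minimal float64 NumPy sketch of a single-block attention forward pass. It assumes the standard flash attention convention in which $\bar{\mathbf{P}} = \exp(\mathbf{S} - \mathrm{rowmax}(\mathbf{S}))$ is the unnormalized softmax matrix, so every row of $\bar{\mathbf{P}}$ contains an entry exactly equal to 1, and it computes $\boldsymbol{\delta} = \mathrm{rowsum}(d\mathbf{O} \circ \mathbf{O})$. The `shift` parameter is a hypothetical knob for illustration, not the authors' modification: it only shows that subtracting something larger than the row maximum keeps every entry of $\bar{\mathbf{P}}$ below 1 while leaving $\mathbf{O}$ unchanged in exact arithmetic. The sketch does not model BF16 rounding.

```python
import numpy as np

def attention_forward(Q, K, V, shift=0.0):
    """Single-block attention with a max-subtracted ("safe") softmax.

    P_bar = exp(S - (rowmax(S) + shift)) is the unnormalized probability matrix.
    With shift == 0 the row maximum of P_bar is exactly 1; a positive `shift`
    (a hypothetical knob, for illustration only) keeps every entry below 1.
    """
    S = Q @ K.T / np.sqrt(Q.shape[-1])          # attention scores
    m = S.max(axis=-1, keepdims=True) + shift   # per-row value to subtract
    P_bar = np.exp(S - m)                       # unnormalized softmax
    ell = P_bar.sum(axis=-1, keepdims=True)     # per-row normalizer
    O = (P_bar @ V) / ell                       # the shift cancels here
    return O, P_bar

def delta_term(dO, O):
    # delta = rowsum(dO * O): one scalar per query row, used in the backward pass
    return (dO * O).sum(axis=-1)

rng = np.random.default_rng(0)
Q, K, V, dO = (rng.standard_normal((4, 8)) for _ in range(4))

O0, P0 = attention_forward(Q, K, V)             # standard safe softmax
O1, P1 = attention_forward(Q, K, V, shift=0.5)  # all P_bar entries < 1

print(P0.max(axis=-1))       # every row maximum is exactly 1.0
print(P1.max() < 1.0)        # True
print(np.allclose(O0, O1))   # True: O is invariant to the shift
print(delta_term(dO, O0))    # one delta value per query row
```

In exact arithmetic the choice of subtracted constant is irrelevant to $\mathbf{O}$; the paper's point is that in BF16 it is not, because entries equal to 1 interact with rounding in a biased way.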
Abstract
The pursuit of computational efficiency has driven the adoption of low-precision formats for training transformer models. However, this progress is often hindered by notorious training instabilities. This paper provides the first mechanistic explanation for a long-standing and unresolved failure case in which training with flash attention in low-precision settings leads to catastrophic loss explosion. Our in-depth analysis reveals that the failure is not a random artifact but is caused by two intertwined phenomena: the emergence of similar low-rank representations within the attention mechanism and the compounding effect of biased rounding errors inherent in low-precision arithmetic. We demonstrate how these factors create a vicious cycle of error accumulation that corrupts weight updates, ultimately derailing the training dynamics. To validate our findings, we introduce a minimal modification to flash attention that mitigates the bias in rounding errors. This simple change stabilizes the training process, confirming our analysis and offering a practical solution to this persistent problem. Code is available at https://github.com/ucker/why-low-precision-training-fails.
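Why biased rounding errors are more dangerous than random ones can be illustrated in isolation. The sketch below emulates BF16 in NumPy (an assumption: NumPy has no native bfloat16 type, so float32 values are rounded to their top 16 bits with round-to-nearest-even; NaN and Inf are not handled) and repeatedly adds a small increment to an accumulator. Every individual rounding error has the same sign, so the errors accumulate instead of cancelling. This is a generic illustration of one-signed rounding error, not a reproduction of the specific flash attention computation analyzed in the paper.

```python
import numpy as np

def to_bf16(x):
    """Round float32 values to the nearest BF16 value, returned as float32.

    Emulation sketch: keep the top 16 bits of the float32 pattern using
    round-to-nearest-even. NaN/Inf are not treated specially.
    """
    x = np.ascontiguousarray(x, dtype=np.float32)
    bits = x.view(np.uint32).astype(np.uint64)   # widen to avoid overflow
    lsb = (bits >> 16) & 1                       # last kept bit, for ties-to-even
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop low 16 bits
    return rounded.astype(np.uint32).view(np.float32)

def accumulate(step, n_steps):
    """Add `step` n_steps times, rounding the accumulator to BF16 after each add."""
    acc_bf16 = to_bf16(np.array([1.0], dtype=np.float32))
    acc_ref = 1.0                                # float64 reference
    for _ in range(n_steps):
        acc_bf16 = to_bf16(acc_bf16 + np.float32(step))
        acc_ref += step
    return float(acc_bf16[0]), acc_ref

bf16_sum, ref_sum = accumulate(0.001, 1000)
print(bf16_sum, ref_sum)
```

Near 1.0 the BF16 spacing is $2^{-7} = 0.0078125$, so each addition of 0.001 rounds back down to 1.0: the BF16 accumulator stays at 1.0 while the float64 reference reaches about 2.0. An unbiased error pattern would tend to cancel over many steps; a one-signed one grows linearly with the number of steps.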
