
Is Flash Attention Stable?

Alicia Golden, Samuel Hsia, Fei Sun, Bilge Acun, Basil Hosmer, Yejin Lee, Zachary DeVito, Jeff Johnson, Gu-Yeon Wei, David Brooks, Carole-Jean Wu

TL;DR

Large-scale model training can exhibit instability that is potentially linked to numeric deviation. The paper introduces a two-phase framework, consisting of a numerical microbenchmark and a Wasserstein-distance analysis of model weights, and applies it to Flash Attention to quantify numeric deviation and bound its downstream effects. It shows that Flash Attention exhibits notable numeric deviation at low precision, but that the resulting changes in model weights are modest and bounded, roughly 2-5x smaller than those caused by low-precision training. The results provide a principled method for quantifying and contextualizing training optimizations, informing safer deployment of attention-acceleration techniques.

Abstract

Training large-scale machine learning models poses distinct system challenges, given both the size and complexity of today's workloads. Recently, many organizations training state-of-the-art Generative AI models have reported cases of instability during training, often taking the form of loss spikes. Numeric deviation has emerged as a potential cause of this training instability, although quantifying this is especially challenging given the costly nature of training runs. In this work, we develop a principled approach to understanding the effects of numeric deviation, and construct proxies to put observations into context when downstream effects are difficult to quantify. As a case study, we apply this framework to analyze the widely-adopted Flash Attention optimization. We find that Flash Attention sees roughly an order of magnitude more numeric deviation as compared to Baseline Attention at BF16 when measured during an isolated forward pass. We then use a data-driven analysis based on the Wasserstein Distance to provide upper bounds on how this numeric deviation impacts model weights during training, finding that the numerical deviation present in Flash Attention is 2-5 times less significant than low-precision training.
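The second phase of the framework compares model weights using the Wasserstein Distance. Below is a minimal sketch of that distance, assuming each weight tensor is flattened into a 1-D sample and that the two runs being compared produce tensors of the same size; the exact procedure (which layers are compared, how results are aggregated into upper bounds) is not described in this section. The helper `wasserstein_1d` and all weight values are hypothetical and purely illustrative.

```python
from typing import Sequence


def wasserstein_1d(u: Sequence[float], v: Sequence[float]) -> float:
    """First Wasserstein (earth mover's) distance between two equal-size 1-D
    empirical distributions with uniform weights.

    In one dimension the optimal transport plan pairs the i-th smallest value
    of u with the i-th smallest value of v, so W1 reduces to the mean absolute
    difference between the sorted samples.
    """
    if len(u) != len(v) or len(u) == 0:
        raise ValueError("expected two non-empty samples of equal size")
    return sum(abs(a - b) for a, b in zip(sorted(u), sorted(v))) / len(u)


# Hypothetical flattened weights from three runs; made-up numbers, not paper results.
w_baseline = [0.10, -0.20, 0.05, 0.30]  # reference run
w_flash = [0.11, -0.20, 0.05, 0.29]     # run with the optimization under study
w_lowprec = [0.14, -0.25, 0.00, 0.33]   # run with a known change used as context

d_flash = wasserstein_1d(w_baseline, w_flash)
d_lowprec = wasserstein_1d(w_baseline, w_lowprec)
print(f"W1(baseline, flash)   = {d_flash:.4f}")
print(f"W1(baseline, lowprec) = {d_lowprec:.4f}")
```

Because only sorted values are compared, W1 measures how the weight distribution as a whole shifts rather than per-parameter agreement. Placing the distance of interest next to the distance produced by a change whose effects are already understood (here, a stand-in for low-precision training) is one way to read the contextualizing step the abstract describes. For larger or unequal-size samples, `scipy.stats.wasserstein_distance` computes the same 1-D quantity.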

Paper Structure (11 sections, 8 figures)

Figures (8)

  • Figure 1: Flash Attention Tiling Operation. Flash Attention uses tiling and recomputation to eliminate the need for the large $N\times N$ similarity matrix. The output of the $QK^T$ dot product is instead computed block by block, as shown in green.
  • Figure 2: Experimental Methodology. (1) We implement a numerical microbenchmark of the Flash Attention operation, which allows experimentation with different numerical precisions as well as testing of various optimizations throughout the algorithm. Our framework allows direct comparison of the Attention Matrix output between Baseline Attention, Flash Attention, and our numeric re-implementation. (2) We use a data-driven procedure to contextualize this numeric difference by examining model weight changes over the course of training.
  • Figure 3: A sweep of numeric precision reveals a numerical difference between Flash Attention and Baseline Attention that varies with numerical precision. As the number format changes from BF16 to FP64, the numeric deviation between Flash Attention and Baseline Attention decreases.
  • Figure 4: Comparison of Flash Attention at different number formats to the golden value of Baseline Attention computed at FP64. We find that Flash Attention sees roughly 10x more numeric deviation than Baseline Attention at BF16.
  • Figure 5: Impact of Sequence Length on Numerical Deviation of Flash Attention. (a) Increasing sequence length increases the maximum difference between Attention matrix outputs. (b) The difference between attention output distributions, measured with mean and standard deviation, reflects a similar trend.
  • ...and 3 more figures
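Figures 1, 3 and 4 describe comparing a tiled, Flash-style attention computation against Baseline Attention, using an FP64 Baseline Attention output as the golden value. The following is a minimal pure-Python sketch of that kind of comparison, not the authors' microbenchmark: `baseline_attention`, `tiled_attention`, `to_bf16`, the block size of 4, and the sizes N=32, d=16 are all illustrative assumptions. BF16 is emulated by rounding every intermediate result, including running sums, to BF16. That is harsher than real GPU kernels, which typically accumulate matrix products in FP32, so the magnitudes it prints should not be read as the paper's measurements.

```python
import math
import random
import struct
from typing import Callable, Dict, List

Matrix = List[List[float]]
Round = Callable[[float], float]


def fp64(x: float) -> float:
    """No extra rounding: plain Python floats are IEEE-754 doubles."""
    return x


def to_bf16(x: float) -> float:
    """Emulated BF16: round to float32, then keep the top 16 bits using
    round-to-nearest-even. Assumes finite values within float32 range."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rsum(values: List[float], r: Round) -> float:
    """Sequential sum, rounding the running total after every addition."""
    acc = 0.0
    for x in values:
        acc = r(acc + x)
    return acc


def dot(a: List[float], b: List[float], r: Round) -> float:
    return rsum([r(x * y) for x, y in zip(a, b)], r)


def baseline_attention(Q: Matrix, K: Matrix, V: Matrix, r: Round) -> Matrix:
    """Standard attention: build each full row of S = QK^T / sqrt(d), then softmax(S) V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [r(dot(q, k, r) * scale) for k in K]
        m = max(s)  # subtract the row max for a numerically safe softmax
        p = [r(math.exp(r(x - m))) for x in s]
        denom = rsum(p, r)
        out.append([r(dot(p, [v[c] for v in V], r) / denom) for c in range(len(V[0]))])
    return out


def tiled_attention(Q: Matrix, K: Matrix, V: Matrix, r: Round, block: int = 4) -> Matrix:
    """Flash-style attention: visit K/V in blocks, keeping a running max (m),
    running softmax denominator (l) and unnormalized output (acc), so the full
    N x N score matrix is never materialized."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    out = []
    for q in Q:
        m, l, acc = -math.inf, 0.0, [0.0] * d_v
        for start in range(0, len(K), block):
            kb, vb = K[start:start + block], V[start:start + block]
            s = [r(dot(q, k, r) * scale) for k in kb]
            m_new = max(m, max(s))
            alpha = r(math.exp(m - m_new))  # rescale factor for earlier blocks (0 on first block)
            p = [r(math.exp(r(x - m_new))) for x in s]
            l = r(r(l * alpha) + rsum(p, r))
            acc = [r(r(acc[c] * alpha) + dot(p, [v[c] for v in vb], r)) for c in range(d_v)]
            m = m_new
        out.append([r(a / l) for a in acc])  # normalize once at the end
    return out


def max_abs_diff(A: Matrix, B: Matrix) -> float:
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))


def random_bf16_matrix(rows: int, cols: int) -> Matrix:
    # Inputs are pre-rounded to BF16 so every run sees identical, representable inputs.
    return [[to_bf16(random.gauss(0.0, 1.0)) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
N, d = 32, 16  # hypothetical sequence length and head dimension
Q, K, V = (random_bf16_matrix(N, d) for _ in range(3))

golden = baseline_attention(Q, K, V, fp64)  # FP64 Baseline Attention as the golden value
deviation: Dict[str, Dict[str, float]] = {}
for name, r in [("FP64", fp64), ("BF16", to_bf16)]:
    deviation[name] = {
        "baseline": max_abs_diff(baseline_attention(Q, K, V, r), golden),
        "tiled": max_abs_diff(tiled_attention(Q, K, V, r), golden),
    }
    print(f"{name}: baseline dev = {deviation[name]['baseline']:.2e}, "
          f"tiled dev = {deviation[name]['tiled']:.2e}")
```

At FP64 the tiled and baseline outputs agree to within double-precision round-off, which checks that the online-softmax rescaling is exact in real arithmetic; any larger gap at BF16 comes from where and how often rounding happens. Increasing `N` with the same block size is a quick way to probe the sequence-length trend described for Figure 5, though the sketch makes no claim to reproduce the paper's numbers.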