Rank-Aware Spectral Bounds on Attention Logits for Stable Low-Precision Training

Seyed Morteza Emadi

TL;DR

The paper introduces a rank-aware, geometry-driven calibration framework that stabilizes FP8 training of transformers by bounding attention logits with a per-layer scale derived from a spectral norm. It derives a tighter, rank-aware probabilistic guarantee on overflow risk, which enables principled selection of a calibration factor $\alpha$ and of a per-layer scale $\text{scale}^{(\ell)}$ that adapts to the weight geometry. An efficient, implicit power-iteration procedure estimates the spectral norm of the interaction matrix without forming large matrices and extends to grouped query attention (GQA) and rotary position embeddings (RoPE). A memory-safe auto-$\alpha$ scheme tunes FP8 utilization during steady-state training. Experiments from GPT-2 XL to Llama-2-70B show zero overflows in transient scenarios while maintaining comparable MMLU accuracy, with modest overhead and better use of the FP8 dynamic range. The approach offers a practical, theoretically grounded path to reliable low-precision transformer training that remains compatible with fused attention kernels.
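To make the per-layer scale concrete, the NumPy sketch below turns a spectral norm into an FP8 scale factor. The overview does not reproduce the paper's formula for $\text{scale}^{(\ell)}$, so the form used here is an assumption: it maps $\alpha$ times the Cauchy-Schwarz bound $\sigma_{QK} B_X^2/\sqrt{d_h}$ onto the FP8 E4M3 maximum of 448, treating $\alpha = 1$ as the worst-case setting. The helper names `logit_bound` and `geometry_aware_scale` and the toy layer sizes are hypothetical.

```python
import numpy as np

# Largest finite magnitude representable in FP8 E4M3.
FP8_E4M3_MAX = 448.0


def logit_bound(sigma_qk: float, b_x: float, d_h: int) -> float:
    """Worst-case |S_ij| for S_ij = x_i^T M x_j / sqrt(d_h), given
    ||x_i||_2 <= b_x for every token and ||M||_2 = sigma_qk (Cauchy-Schwarz)."""
    return sigma_qk * b_x**2 / np.sqrt(d_h)


def geometry_aware_scale(sigma_qk: float, b_x: float, d_h: int,
                         alpha: float = 1.0,
                         fp8_max: float = FP8_E4M3_MAX) -> float:
    """Hypothetical per-layer scale: maps alpha * (worst-case bound) onto fp8_max.
    alpha = 1 covers the worst case; alpha < 1 trades headroom for utilization."""
    return fp8_max / (alpha * logit_bound(sigma_qk, b_x, d_h))


# Toy layer with random weights and activations (illustrative sizes only).
rng = np.random.default_rng(0)
d, d_h, n = 64, 16, 32
W_q = rng.standard_normal((d, d_h)) / np.sqrt(d)
W_k = rng.standard_normal((d, d_h)) / np.sqrt(d)
X = rng.standard_normal((n, d))

M = W_q @ W_k.T                        # interaction matrix (formed only because d is small)
sigma_qk = np.linalg.norm(M, 2)        # exact spectral norm (largest singular value)
b_x = np.linalg.norm(X, axis=1).max()  # token-norm bound B_X for this batch
S = X @ M @ X.T / np.sqrt(d_h)         # attention logits

scale = geometry_aware_scale(sigma_qk, b_x, d_h, alpha=1.0)
print(np.abs(S * scale).max() <= FP8_E4M3_MAX)  # True: guaranteed when alpha = 1
```

Because the scale depends only on the weights and the norm bound, it can be computed before any logits are materialized; smaller $\alpha$ spends more of the FP8 range at the cost of the worst-case guarantee.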

Abstract

Attention scores in transformers are bilinear forms $S_{ij} = x_i^\top M x_j / \sqrt{d_h}$ whose maximum magnitude governs overflow risk in low-precision training. We derive a \emph{rank-aware concentration inequality}: when the interaction matrix $M = W^Q W^{K\top}$ has rank $r \ll d$, tail probabilities for $\max_{i,j}|S_{ij}|$ decay as $\exp(-d^{2}\alpha^{2}/(\gamma r))$ rather than $\exp(-d\alpha^{2})$, where $\gamma > 1$ is a typicality parameter. For transformer attention, where $r = d_h$, this yields $8$--$28\times$ tighter concentration than rank-agnostic bounds in modern architectures. We apply this result to FP8 training, deriving \emph{geometry-aware scale factors} that provide principled overflow guarantees without observing activations. The method computes per-layer scales from the spectral norm $\|W^Q W^{K\top}\|_2$ via implicit power iteration, includes a grouped query attention formulation that avoids key expansion, and remains compatible with fused attention kernels. Across GPT-2 XL to Llama-2-70B, geometry-aware scaling eliminates overflows in transient scenarios where delayed scaling fails, while achieving comparable downstream MMLU accuracy.
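The implicit power iteration can be sketched as follows, assuming the projections are stored as $d \times d_h$ arrays so that $M = W^Q W^{K\top}$ is $d \times d$. Each step applies $M^\top M$ through four thin matrix-vector products, so $M$ itself is never formed. The function name `implicit_spectral_norm`, the fixed iteration count, and the random start are illustrative choices rather than the paper's settings, and the GQA and RoPE variants are omitted.

```python
import numpy as np


def implicit_spectral_norm(W_q: np.ndarray, W_k: np.ndarray,
                           iters: int = 100, seed: int = 0) -> float:
    """Estimate ||W_q @ W_k.T||_2 by power iteration on M^T M, where
    M = W_q @ W_k.T, using only products with the thin (d x d_h) factors."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(W_k.shape[0])
    v /= np.linalg.norm(v)
    for _ in range(iters):
        u = W_q @ (W_k.T @ v)   # u = M v
        v = W_k @ (W_q.T @ u)   # v = M^T u = M^T M v
        v /= np.linalg.norm(v)
    # For a unit vector v, ||M v|| <= ||M||_2, with equality at the top singular vector.
    return float(np.linalg.norm(W_q @ (W_k.T @ v)))


rng = np.random.default_rng(1)
W_q = rng.standard_normal((256, 32))
W_k = rng.standard_normal((256, 32))
estimate = implicit_spectral_norm(W_q, W_k)
exact = np.linalg.norm(W_q @ W_k.T, 2)  # dense reference, feasible only for small d
print(estimate, exact)  # the estimate never exceeds the exact value (up to rounding)
```

Each iteration costs $O(d\,d_h)$ instead of the $O(d^2)$ needed to touch $M$, which is what makes per-layer recomputation cheap; convergence speed depends on the gap between the two largest singular values.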

Paper Structure (98 sections, 9 theorems, 55 equations, 3 figures, 11 tables, 4 algorithms)

Key Result

Proposition 3.1 (Naive bound)

Let $\|x_i\|_2 \leq B_X$ for all tokens. Then:
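Under this assumption, a standard two-step argument bounds every logit: Cauchy-Schwarz gives a bound through the interaction matrix, and submultiplicativity of the spectral norm, $\|W^Q W^{K\top}\|_2 \le \|W^Q\|_2 \|W^K\|_2$, relaxes it to a bound through the individual projections. The chain below is a sketch derived from the definitions in the abstract; it is consistent with the titles of Proposition 3.1 (naive bound), Proposition 3.2 (interaction bound) and Corollary 3.3, but it is not a verbatim restatement of those results.

```latex
\begin{aligned}
|S_{ij}| &= \frac{|x_i^\top W^Q W^{K\top} x_j|}{\sqrt{d_h}}
  \le \frac{\|x_i\|_2 \, \|W^Q W^{K\top}\|_2 \, \|x_j\|_2}{\sqrt{d_h}}
  \le \frac{\|W^Q W^{K\top}\|_2 \, B_X^2}{\sqrt{d_h}} \\
&\le \frac{\|W^Q\|_2 \, \|W^K\|_2 \, B_X^2}{\sqrt{d_h}}
\end{aligned}
```

The first line is the quantity a per-layer scale needs; the second line can only be looser, which is why the interaction form is preferred.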

Figures (3)

  • Figure 1: Spectral norm $\sigma_{QK}^{(\ell)}$ by layer for all four models, computed on pretrained weights. Early layers consistently exhibit larger spectral norms, with layer 0 being the maximum in three of four models.
  • Figure 2: Response to a $4\times$ weight spike at step 10 (GPT-2 XL). (a) Maximum scaled attention logit. (b) Scale factor over time. Delayed scaling overflows because its history buffer contains no information about the weight change. Geometry-aware scaling adapts instantaneously because the scale factor is computed from the current weights.
  • Figure 3: Training loss comparison for delayed scaling, geometry-aware scaling (conservative), and geometry-aware scaling with auto-$\alpha$ on Llama-2-13B (MMLU STEM, 3000 steps). All methods converge to a similar final loss ($\approx 0.012$), yet downstream MMLU accuracy differs substantially (Table \ref{tab:mmlu}). The vertical dashed line marks the step at which auto-$\alpha$ calibration completes (step 100). The conservative variant shows slightly higher loss throughout training due to reduced FP8 utilization (0.5% vs. 31.2% for auto-$\alpha$).
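The mechanism behind Figure 2 can be reproduced in a toy setting. The sketch below compares an amax-history (delayed) scale with a scale computed from the current weights when $W^Q$ is abruptly multiplied by 4. Everything here is a hypothetical stand-in for the paper's experiment: random Gaussian weights and activations, a history window of 16, a safety margin of 2 for delayed scaling, a unit scale when no history exists yet, and the helper name `simulate_spike`.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude


def simulate_spike(steps: int = 20, spike_step: int = 10, spike: float = 4.0,
                   history_len: int = 16, margin: float = 2.0, seed: int = 0):
    """Toy comparison of delayed (amax-history) scaling and geometry-aware
    scaling when W_q is abruptly multiplied by `spike`. Returns per-step
    overflow flags for each method."""
    rng = np.random.default_rng(seed)
    d, d_h, n = 64, 16, 32
    W_q = rng.standard_normal((d, d_h)) / np.sqrt(d)
    W_k = rng.standard_normal((d, d_h)) / np.sqrt(d)
    history = []  # amax values observed at previous steps
    flags = {"delayed": [], "geometry": []}
    for t in range(steps):
        if t == spike_step:
            W_q = W_q * spike  # hypothetical transient: query weights jump
        X = rng.standard_normal((n, d))
        M = W_q @ W_k.T
        amax = np.abs(X @ M @ X.T).max() / np.sqrt(d_h)  # true max logit this step

        # Delayed scaling: uses only past amax values (plus a safety margin).
        delayed_scale = FP8_E4M3_MAX / (margin * max(history)) if history else 1.0
        # Geometry-aware scaling: uses the Cauchy-Schwarz bound for the current weights.
        b_x = np.linalg.norm(X, axis=1).max()
        geo_scale = FP8_E4M3_MAX * np.sqrt(d_h) / (np.linalg.norm(M, 2) * b_x**2)

        flags["delayed"].append(bool(amax * delayed_scale > FP8_E4M3_MAX))
        flags["geometry"].append(bool(amax * geo_scale > FP8_E4M3_MAX))
        history = (history + [amax])[-history_len:]
    return flags


flags = simulate_spike()
overflow_steps = [t for t, hit in enumerate(flags["delayed"]) if hit]
print(overflow_steps, any(flags["geometry"]))
```

The delayed scale at the spike step was chosen from logits produced by the old weights, so a 4× jump exceeds its 2× margin; the geometry-aware scale is recomputed from the spiked weights and cannot overflow in this model.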

Theorems & Definitions (18)

  • Proposition 3.1: Naive bound
  • Proposition 3.2: Interaction bound
  • Corollary 3.3: Interaction bound is tighter
  • Proposition 3.4: Rank-aware overflow probability bound
  • Proposition 3.5: RoPE preserves norms
  • Corollary 3.6: Geometry-aware scaling extends to RoPE
  • Proposition 4.1: Implicit GQA power iteration
  • Proof
  • Proof
  • Proof
  • ...and 8 more
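Proposition 3.5 states that RoPE preserves norms. The NumPy sketch below checks this numerically using the common pairwise-rotation form of RoPE with base 10000; the function `rope` and these conventions are assumptions of the illustration (implementations differ, for example in how dimensions are paired). Since each $2 \times 2$ block is a rotation, any bound expressed through $\|q\|_2$ and $\|k\|_2$ is unaffected, and the logit depends only on the relative offset between positions.

```python
import numpy as np


def rope(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotary position embedding for one vector of even length d: each pair
    (x[2i], x[2i+1]) is rotated by the angle pos * base**(-2i/d)."""
    d = x.shape[-1]
    theta = pos * base ** (-np.arange(0, d, 2) / d)
    cos, sin = np.cos(theta), np.sin(theta)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal(16)
k = rng.standard_normal(16)
# Each 2x2 block is a rotation, so norms are unchanged at every position ...
print(np.linalg.norm(q), np.linalg.norm(rope(q, 7)))
# ... and the logit depends only on the relative offset between positions.
print(np.dot(rope(q, 5), rope(k, 3)), np.dot(rope(q, 2), rope(k, 0)))
```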