
Attn-QAT: 4-Bit Attention With Quantization-Aware Training

Peiyuan Zhang, Matthew Noto, Wenxuan Tan, Chengquan Jiang, Will Lin, Wei Zhou, Hao Zhang

TL;DR

Attn-QAT applies quantization-aware training to FP4 attention. It recovers the quality drop caused by FP4 attention without the explicit outlier-mitigation heuristics used in prior FP4 attention methods, and delivers up to a 1.5x speedup on an RTX 5090.

Abstract

Achieving reliable 4-bit attention is a prerequisite for end-to-end FP4 computation on emerging FP4-capable GPUs, yet attention remains the main obstacle because of FP4's tiny dynamic range and the heavy-tailed activations in attention. This paper presents the first systematic study of 4-bit quantization-aware training (QAT) for attention. We find that "drop-in" QAT, which naively combines an FP4 forward pass with a high-precision Flash Attention (FA)-style backward pass, leads to training instability. We identify two key principles for stable FP4 attention: (1) recomputing attention scores in the backward pass at the same low precision used in the forward pass, and (2) resolving implicit precision assumptions in FA's gradient calculation. Based on these insights, we propose Attn-QAT and implement fused Triton kernels for training as well as FP4 inference kernels. Across diffusion and language models, Attn-QAT recovers the quality drop from FP4 attention without the explicit outlier-mitigation heuristics used in prior FP4 attention methods, and delivers up to a 1.5x speedup on an RTX 5090. Video demos can be found at https://drive.google.com/drive/folders/190F6xbBDUF2kGQYIcXBt3ehSYij5jlim?usp=sharing.
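Principle (2) refers to an implicit precision assumption in FA's gradient calculation. The derivation below is the standard FlashAttention backward pass, not the paper's own formulation; it shows where one such assumption enters. The row term $D_i$ that FA computes from $O$ and $dO$ equals the softmax-Jacobian term only if $O = PV$ holds exactly for the same $P$ and $V$ the backward pass uses. How Attn-QAT resolves this is specified in the paper, not here.

```latex
\begin{aligned}
S &= \frac{QK^\top}{\sqrt{d}}, \qquad P = \operatorname{softmax}_{\text{row}}(S), \qquad O = PV,\\
dV &= P^\top dO, \qquad dP = dO\,V^\top,\\
dS_{ij} &= P_{ij}\Big(dP_{ij} - \sum_k P_{ik}\,dP_{ik}\Big), \qquad
dQ = \frac{dS\,K}{\sqrt{d}}, \qquad dK = \frac{dS^\top Q}{\sqrt{d}},\\
\sum_k P_{ik}\,dP_{ik} &= \sum_k P_{ik} \sum_c dO_{ic}\,V_{kc} = \sum_c dO_{ic}\,(PV)_{ic},\\
D_i &:= \sum_c dO_{ic}\,O_{ic} \;=\; \sum_k P_{ik}\,dP_{ik} \quad \text{only if } O = PV .
\end{aligned}
```

In a drop-in QAT setup, $O$ is produced by an FP4 forward pass from quantized operands. A high-precision backward pass instead recomputes $P$ and uses $V$ at full precision, so the shortcut $D_i = \sum_c dO_{ic} O_{ic}$ no longer matches the term it replaces.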

Paper Structure

This paper contains 31 sections, 11 equations, 8 figures, 4 tables, and 3 algorithms.

Figures (8)

  • Figure 1: Both NVFP4 attention and SageAttention3 suffer from a significant quality drop on Wan 2.1 14B. Our proposed method, Attn-QAT, recovers this quality drop through quantization-aware training. Note that temporal inconsistency is hard to visualize in sampled frames; non-cherry-picked video samples are attached in the qualitative-results appendix to better showcase the superior quality of Attn-QAT.
  • Figure 2: Win–Tie–Lose blind human evaluation on 99 randomly sampled VBench prompts for Wan 2.1 14B. Attn-QAT matches BF16 attention in perceived visual quality.
  • Figure 3: Training dynamics for diffusion and language models. (a–b) Gradient norm and loss during Wan 2.1 1.3B finetuning under different Attn-QAT configurations. (c) Finetuning loss curves of Qwen3-14B comparing BF16 attention and Attn-QAT.
  • Figure 4: The Triton forward pass (fake quantization with BF16 GEMM and FP4 emulation) and the CUDA forward pass (real FP4 quantization and FP4 GEMM) produce visually indistinguishable videos, indicating close numerical agreement between the two implementations.
  • Figure 5: Kernel throughput on RTX 5090. We compare attention kernel performance with head dimensions 128 (top) and 64 (bottom), using a batch size of 16 and 16 attention heads. All results report end-to-end throughput.
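Figure 4 compares two forward passes. The Triton one uses fake quantization: values are rounded to FP4 and then multiplied in BF16. The CUDA one runs real FP4 GEMMs. The NumPy sketch below is illustrative and is not the authors' kernels. It shows why the two paths can agree closely. With block-scaled FP4 operands, multiplying the dequantized tensors gives the same result as summing per-block FP4 dot products and then rescaling by the block scales.

The sketch assumes an NVFP4-like layout, with E2M1 values and one scale per 16 elements. It simplifies that layout in three ways:
- block scales stay in float64 instead of FP8 E4M3,
- the per-tensor scale is omitted,
- rounding ties break toward the smaller magnitude.

All function names are hypothetical.

```python
import numpy as np

# Non-negative magnitudes representable in FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
BLOCK = 16  # NVFP4-like: one scale per 16 contiguous elements


def quantize_fp4(x, block=BLOCK):
    """Hypothetical helper: map the last axis of x onto the E2M1 grid, one scale per block.

    Simplifying assumptions: scales stay in float64 (not FP8 E4M3), there is no
    per-tensor scale, and ties round toward the smaller grid value.
    Returns (grid values shaped (..., K/block, block), scales shaped (..., K/block, 1)).
    """
    *lead, k = x.shape
    assert k % block == 0, "last dimension must be a multiple of the block size"
    xb = x.reshape(*lead, k // block, block)
    amax = np.abs(xb).max(axis=-1, keepdims=True)
    scale = np.where(amax > 0, amax / E2M1_GRID[-1], 1.0)  # block max maps to 6.0
    y = np.abs(xb) / scale
    # Nearest grid point; argmin returns the first (smaller) index on ties.
    idx = np.abs(y[..., None] - E2M1_GRID).argmin(axis=-1)
    q = np.sign(xb) * E2M1_GRID[idx]
    return q, scale


def fake_quantize(x, block=BLOCK):
    """Quantize-dequantize: carries FP4 rounding error but returns ordinary floats."""
    q, scale = quantize_fp4(x, block)
    return (q * scale).reshape(x.shape)


def fake_quant_matmul(a, b):
    """'Fake' path: dequantize both operands, then do a high-precision GEMM."""
    # b is (K, N); quantize along K by working on its transpose.
    return fake_quantize(a) @ fake_quantize(b.T).T


def blockwise_fp4_matmul(a, b, block=BLOCK):
    """'Real'-style path: dot products of FP4 values per block, scales applied afterwards."""
    qa, sa = quantize_fp4(a, block)    # (M, G, block), (M, G, 1)
    qb, sb = quantize_fp4(b.T, block)  # (N, G, block), (N, G, 1)
    partial = np.einsum("mgk,ngk->mng", qa, qb)      # per-block partial sums (M, N, G)
    scales = sa[:, None, :, 0] * sb[None, :, :, 0]   # combined block scales (M, N, G)
    return (partial * scales).sum(axis=-1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.standard_normal((4, 64))
    b = rng.standard_normal((64, 8))
    print(np.allclose(fake_quant_matmul(a, b), blockwise_fp4_matmul(a, b)))  # True
```

In a real kernel, the remaining differences between the two paths come from accumulation order and from the precision of the scales and accumulators. This float64 sketch does not model either effect.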