SageBwd: A Trainable Low-bit Attention

Jintao Zhang, Marco Chen, Haoxu Wang, Kai Jiang, Ion Stoica, Joseph E. Gonzalez, Jianfei Chen, Jun Zhu

TL;DR

This work investigates why SageBwd, a trainable INT8 attention, showed a performance gap to full-precision attention (FPA) during pre-training. It demonstrates that SageBwd can match FPA in pre-training and concludes that K-smoothing remains essential for training stability, while Q-smoothing provides limited benefit.

Abstract

Low-bit attention, such as SageAttention, has emerged as an effective approach for accelerating model inference, but its applicability to training remains poorly understood. In prior work, we introduced SageBwd, a trainable INT8 attention that quantizes six of the seven attention matrix multiplications while preserving fine-tuning performance. However, SageBwd exhibited a persistent performance gap to full-precision attention (FPA) during pre-training. In this work, we investigate why this gap occurs and demonstrate that SageBwd can match FPA during pre-training. Through experiments and theoretical analysis, we reach four conclusions: (i) QK-norm is necessary for stable training with a large number of tokens per step, (ii) quantization errors arise primarily from the backward-pass score gradient dS, (iii) reducing the number of tokens per step enables SageBwd to match FPA performance in pre-training, and (iv) K-smoothing remains essential for training stability, while Q-smoothing provides limited benefit during pre-training.
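To make the K-smoothing conclusion concrete, here is a minimal NumPy sketch. It is an illustration, not the authors' SageBwd kernel. It assumes symmetric per-tensor INT8 quantization (chosen only for simplicity) and synthetic keys with a large per-channel offset shared by all tokens; `quantize_int8` and `int8_scores` are hypothetical helpers. The sketch first checks that subtracting the mean of K over tokens leaves softmax(QK^T) unchanged, because it shifts each row i of S by the constant q_i · mean(K). It then compares the INT8 error in the attention probabilities with and without smoothing.

```python
import numpy as np

def quantize_int8(x):
    """Hypothetical helper: symmetric per-tensor INT8 quantization (values, scale)."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def softmax(s):
    """Row-wise softmax with max subtraction for numerical stability."""
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def int8_scores(q, k):
    """Hypothetical helper: approximate S = Q K^T from INT8 operands, INT32 accumulation."""
    q8, sq = quantize_int8(q)
    k8, sk = quantize_int8(k)
    acc = q8.astype(np.int32) @ k8.astype(np.int32).T
    return acc.astype(np.float64) * sq * sk

rng = np.random.default_rng(0)
n, d = 128, 64
q = rng.standard_normal((n, d))
# Synthetic keys (an assumption): a large per-channel offset shared by every token.
k = rng.standard_normal((n, d)) + 20.0 * rng.standard_normal(d)

p_ref = softmax(q @ k.T / np.sqrt(d))  # full-precision reference probabilities

# K-smoothing: subtract the per-channel mean of K over the token dimension.
k_smooth = k - k.mean(axis=0, keepdims=True)

# Each row of S shifts by a constant, so the softmax output is unchanged.
assert np.allclose(softmax(q @ k_smooth.T / np.sqrt(d)), p_ref)

err_plain = np.abs(softmax(int8_scores(q, k) / np.sqrt(d)) - p_ref).max()
err_smooth = np.abs(softmax(int8_scores(q, k_smooth) / np.sqrt(d)) - p_ref).max()
print(f"max |P error| without smoothing: {err_plain:.2e}")
print(f"max |P error| with K-smoothing:  {err_smooth:.2e}")
```

In real arithmetic, the smoothing step is exact. Its benefit in this sketch comes only from the narrower value range of K, which gives a finer INT8 scale.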

Paper Structure (39 sections, 11 equations, 4 figures, 2 tables)

Figures (4)

  • Figure 1: Pretraining loss over 78B tokens under different numbers of tokens per step
  • Figure 2: Speed comparison between SageBwd and baselines (RTX 4090, head dim = 128).
  • Figure 3: Speed comparison between SageBwd and baselines (RTX 4090, head dim = 64).
  • Figure 4: Ablation of Q-smoothing and K-smoothing: pretraining loss over 78B tokens under different numbers of tokens per step