
Controlling changes to attention logits

Ben Anson, Laurence Aitchison

TL;DR

This work targets stability during transformer pretraining by constraining changes in attention logits rather than their magnitudes. It introduces QuacK, a method that assigns parameter-dependent learning rates $\eta_Q \propto \|\mathbf{W}_K\|^{-1}$ and $\eta_K \propto \|\mathbf{W}_Q\|^{-1}$ to bound logit updates, supported by a lemma showing the worst-case logit change is bounded independently of weight size. Empirically, QuacK enables higher base learning rates, matches QK norm stability in standard MHA, and outperforms alternatives like QK clip in the MLA setting, while being cheaper and more widely applicable. This offers a practical, normalization-free path to stable, scalable transformer pretraining, especially for architectures that use MLA.
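The learning-rate rule in the TL;DR can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation: it takes the proportionality constant to be the base learning rate, lets the weight norm be either Frobenius or spectral (the figures compare both), and omits the hyperparameter $\tau$, whose exact role is not given here. The names `quack_learning_rates`, `base_lr` and `eps` are hypothetical.

```python
from __future__ import annotations

import math
import random


def frobenius_norm(m: list[list[float]]) -> float:
    return math.sqrt(sum(v * v for row in m for v in row))


def spectral_norm(m: list[list[float]], iters: int = 500, seed: int = 0) -> float:
    # Power iteration on M^T M from a random start; returns an estimate
    # of the largest singular value of M.
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in m[0]]
    for _ in range(iters):
        mv = [sum(a * b for a, b in zip(row, v)) for row in m]
        w = [sum(m[i][j] * mv[i] for i in range(len(m))) for j in range(len(v))]
        n = math.sqrt(sum(x * x for x in w))
        if n == 0.0:
            return 0.0
        v = [x / n for x in w]
    mv = [sum(a * b for a, b in zip(row, v)) for row in m]
    return math.sqrt(sum(x * x for x in mv))


def quack_learning_rates(w_q, w_k, base_lr, norm="fro", eps=1e-8):
    """Hypothetical helper: eta_Q proportional to 1/||W_K||, eta_K to 1/||W_Q||.

    Assumes the proportionality constant is base_lr; eps (also an
    assumption) guards against division by zero for all-zero weights.
    """
    if norm not in ("fro", "spectral"):
        raise ValueError(f"unknown norm: {norm!r}")
    norm_fn = frobenius_norm if norm == "fro" else spectral_norm
    eta_q = base_lr / max(norm_fn(w_k), eps)  # query LR shrinks as W_K grows
    eta_k = base_lr / max(norm_fn(w_q), eps)  # key LR shrinks as W_Q grows
    return eta_q, eta_k


# Usage: per-head learning rates for toy 2x2 query/key weights.
w_q = [[2.0, 0.0], [0.0, 0.0]]
w_k = [[3.0, 4.0], [0.0, 0.0]]
print(quack_learning_rates(w_q, w_k, base_lr=1e-2))
print(quack_learning_rates(w_q, w_k, base_lr=1e-2, norm="spectral"))
```

Note the cross-coupling: the query learning rate is set by the key weights and vice versa, so doubling $\mathbf{W}_K$ halves $\eta_Q$ while leaving $\eta_K$ unchanged.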

Abstract

Stability of neural network weights is critical when training transformer models. The query and key weights are particularly problematic, as they tend to grow large without any intervention. Applying normalization to queries and keys, known as 'QK norm', fixes stability issues in practice, but is not always applicable. For example, QK norm is not compatible with Multi-head Latent Attention (MLA), because it requires full materialization of queries and keys during inference, which MLA avoids. In this paper we suggest that controlling the changes to logits is important for stability. We show that these changes are controllable by assigning parameter-dependent learning rates to the query and key weights. We find that our cheap intervention allows us to increase the base learning rate of the network, outperform other methods in the MLA setting, and achieve performance competitive with QK norm when using Multi-head Attention (MHA).
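A short first-order argument shows why parameter-dependent learning rates can control logit changes. This is a sketch under assumptions, not the paper's proof: treat the logit for unit-norm inputs $\mathbf{x}, \mathbf{y}$ as $\mathbf{x}^\top\mathbf{W}\mathbf{y}$ with $\mathbf{W} = d_\text{head}^{-1/2}\mathbf{W}_Q^\top\mathbf{W}_K$, write the steps as $\Delta\mathbf{W}_{Q/K} = -\eta_{Q/K}\mathbf{U}_{Q/K}$ for hypothetical update directions $\mathbf{U}_{Q/K}$ of bounded norm, and use the spectral norm $\|\cdot\|_2$ and its submultiplicativity.

```latex
\begin{aligned}
\max_{\|\mathbf{x}\|=\|\mathbf{y}\|=1}\left|\mathbf{x}^\top\Delta\mathbf{W}\,\mathbf{y}\right|
  &= \|\Delta\mathbf{W}\|_2, \\
\Delta\mathbf{W}
  &= d_\text{head}^{-1/2}\left(\Delta\mathbf{W}_Q^\top\mathbf{W}_K + \mathbf{W}_Q^\top\Delta\mathbf{W}_K + \Delta\mathbf{W}_Q^\top\Delta\mathbf{W}_K\right), \\
\|\Delta\mathbf{W}\|_2
  &\le d_\text{head}^{-1/2}\left(\eta_Q\|\mathbf{U}_Q\|_2\|\mathbf{W}_K\|_2 + \eta_K\|\mathbf{W}_Q\|_2\|\mathbf{U}_K\|_2 + \eta_Q\eta_K\|\mathbf{U}_Q\|_2\|\mathbf{U}_K\|_2\right), \\
\eta_Q = \frac{\eta}{\|\mathbf{W}_K\|_2},\ \eta_K = \frac{\eta}{\|\mathbf{W}_Q\|_2}
  &\;\Rightarrow\; \|\Delta\mathbf{W}\|_2 \le d_\text{head}^{-1/2}\left(\eta\|\mathbf{U}_Q\|_2 + \eta\|\mathbf{U}_K\|_2 + \frac{\eta^2\|\mathbf{U}_Q\|_2\|\mathbf{U}_K\|_2}{\|\mathbf{W}_Q\|_2\|\mathbf{W}_K\|_2}\right).
\end{aligned}
```

With these learning rates the first-order terms no longer depend on the size of the weights, and the second-order term shrinks as the weights grow; with a fixed learning rate, the first-order terms instead grow linearly in $\|\mathbf{W}_K\|_2$ and $\|\mathbf{W}_Q\|_2$. The same bound holds with Frobenius norms on the right-hand side, since the Frobenius norm upper-bounds the spectral norm and is also submultiplicative.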

Paper Structure

This paper contains 9 sections, 2 theorems, 20 equations, 4 figures, 2 algorithms.

Key Result

Lemma 1

Let $\mathbf{W}_Q,\mathbf{W}_K\in\mathbb{R}^{d_\text{model}\times d_\text{head}}$ be weight matrices corresponding to a particular attention head, and consider the worst-case change in logits for unit-norm inputs, where $\mathbf{W} = d_\text{head}^{-1/2}\mathbf{W}_Q^\top \mathbf{W}_K$. Suppose that the steps for $\mathbf{W}_Q$ and $\mathbf{W}_K$ are given by $\Delta\mathbf{W}_{Q/K} = -\eta_{Q/K}\ldots$
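A small numerical experiment makes the lemma's claim concrete. The Python sketch below is an illustration under assumptions, not the paper's proof: it uses $\mathbf{W} = d_\text{head}^{-1/2}\mathbf{W}_Q^\top\mathbf{W}_K$ as written in the lemma, the Frobenius-norm rule $\eta_Q = \eta/\|\mathbf{W}_K\|_F$, $\eta_K = \eta/\|\mathbf{W}_Q\|_F$, and random unit-Frobenius-norm matrices `u_q`, `u_k` as hypothetical stand-ins for the optimizer's step directions. It scales both weight matrices by 100 and compares the worst-case logit change under a fixed learning rate with the change under the scaled learning rates.

```python
import math
import random


def randn(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def scale(m, c):
    return [[c * v for v in row] for row in m]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def t_matmul(a, b):
    # a^T b for a of shape (n, p) and b of shape (n, q), giving (p, q).
    return [[sum(a[i][r] * b[i][c] for i in range(len(a))) for c in range(len(b[0]))]
            for r in range(len(a[0]))]


def fro(m):
    return math.sqrt(sum(v * v for row in m for v in row))


def spectral(m, iters=1000, seed=0):
    # Power iteration on M^T M. ||M v|| for unit v never exceeds the true
    # spectral norm, so any estimation error is on the low side.
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in m[0]]
    for _ in range(iters):
        mv = [sum(x * y for x, y in zip(row, v)) for row in m]
        w = [sum(m[i][j] * mv[i] for i in range(len(m))) for j in range(len(v))]
        n = math.sqrt(sum(x * x for x in w))
        if n == 0.0:
            return 0.0
        v = [x / n for x in w]
    mv = [sum(x * y for x, y in zip(row, v)) for row in m]
    return math.sqrt(sum(x * x for x in mv))


def worst_case_logit_change(w_q, w_k, u_q, u_k, eta_q, eta_k):
    # Spectral norm of the change in W = d_head^{-1/2} W_Q^T W_K after the
    # steps Delta W_{Q/K} = -eta_{Q/K} U_{Q/K}; for unit-norm inputs x, y
    # this is max |x^T (Delta W) y|.
    d_head = len(w_q[0])
    new_q = add(w_q, scale(u_q, -eta_q))
    new_k = add(w_k, scale(u_k, -eta_k))
    delta = add(t_matmul(new_q, new_k), scale(t_matmul(w_q, w_k), -1.0))
    return spectral(scale(delta, d_head ** -0.5))


rng = random.Random(0)
d_model, d_head, eta = 8, 4, 1e-2
w_q, w_k = randn(d_model, d_head, rng), randn(d_model, d_head, rng)
# Hypothetical update directions, normalized to unit Frobenius norm.
u_q, u_k = randn(d_model, d_head, rng), randn(d_model, d_head, rng)
u_q, u_k = scale(u_q, 1.0 / fro(u_q)), scale(u_k, 1.0 / fro(u_k))

results = {}
for s in (1.0, 100.0):
    wq_s, wk_s = scale(w_q, s), scale(w_k, s)
    fixed = worst_case_logit_change(wq_s, wk_s, u_q, u_k, eta, eta)
    eta_q, eta_k = eta / fro(wk_s), eta / fro(wq_s)  # Frobenius-norm variant
    quack = worst_case_logit_change(wq_s, wk_s, u_q, u_k, eta_q, eta_k)
    bound = d_head ** -0.5 * (2 * eta + eta_q * eta_k)
    results[s] = (fixed, quack, bound)
    print(f"scale={s:g}: fixed LR {fixed:.2e}, scaled LRs {quack:.2e} (bound {bound:.2e})")
```

With a fixed learning rate the worst-case change grows roughly in proportion to the weight scale, whereas with the scaled learning rates it stays below $d_\text{head}^{-1/2}(2\eta + \eta_Q\eta_K)$, a bound that follows from the triangle inequality and submultiplicativity of the Frobenius norm.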

Figures (4)

  • Figure 1: Learning rate of query/key matrices is a critical factor for transformer pretraining stability. Here, four models are trained with a large base learning rate of $\eta = \mathtt{3e-2}$ for each parameter. Decreasing the learning rates of the query and key weights alone (by a factor of $\eta_{Q/K}$) fully stabilizes pretraining. QK norm is shown as a stable baseline.
  • Figure 2: Validation losses when training each method with $\mathtt{attn}\in\{\mathtt{MHA},\,\mathtt{MLA}\}$ and learning rates $\eta \in\{\mathtt{3e-4}, \mathtt{3e-3}, \mathtt{3e-2}\}$. QK clip is unstable at high learning rates (it is omitted from the bottom-right plot because its loss is $\gg 2$). QK norm performs best overall, but it is not suitable for MLA at inference time for efficiency reasons (indicated by the dashed yellow line in the MLA row). QuacK is a sensible alternative: it is stable at high learning rates, performs well, and is applicable to MLA.
  • Figure 3: Performance differences are small when the MHA variant of QuacK is applied with different norms. We show validation losses when training a small model ($\sim 100$M parameters) with this algorithm modulating the query and key weight learning rates. Curves differ in the value of the hyperparameter $\tau$ and in whether the query and key weights are measured with the Frobenius or the spectral norm.
  • Figure 4: Maximum logit (left) and average absolute change in logits (right) throughout training, with a base learning rate of $\eta = \mathtt{3e-3}$. We show the middle head of the middle layer (head 16, layer 8) while training with MLA.
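The two quantities tracked in Figure 4 can be computed from one head's queries and keys at two consecutive checkpoints, evaluated on the same inputs. The sketch below is hypothetical: it assumes causal attention (only positions $j \le i$ are counted), takes the max logit to be the largest raw logit rather than the largest absolute value, and averages the absolute change over all unmasked query-key pairs. The paper's exact definitions may differ, and the function names are illustrative.

```python
import math


def causal_logits(q, k):
    # Flattened logits q_i . k_j / sqrt(d_head) over causal positions j <= i.
    scale = 1.0 / math.sqrt(len(q[0]))
    return [scale * sum(a * b for a, b in zip(q[i], k[j]))
            for i in range(len(q)) for j in range(i + 1)]


def max_logit(logits):
    return max(logits)


def mean_abs_logit_change(before, after):
    # Average |after - before| over matching query-key pairs.
    return sum(abs(a - b) for a, b in zip(after, before)) / len(before)


# Usage: a 2-token toy sequence where the first key grows between checkpoints.
q = [[1.0, 0.0], [0.0, 1.0]]
before = causal_logits(q, [[2.0, 0.0], [0.0, 2.0]])
after = causal_logits(q, [[3.0, 0.0], [0.0, 2.0]])
print(max_logit(after), mean_abs_logit_change(before, after))
```

Here only the logit for the first query-key pair changes, by $1/\sqrt{2}$, so the average over the three causal pairs is $1/(3\sqrt{2})$.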

Theorems & Definitions (3)

  • Lemma 1
  • Lemma 1
  • Proof