Scale-invariant Attention

Ben Anson, Xi Wang, Laurence Aitchison

TL;DR

This work tackles the challenge of generalizing attention to longer contexts in Transformers by introducing two scale-invariance desiderata: scale-invariant total attention and scale-invariant attention sparsity. Under a Gaussian-logit model, it derives a simple, position-dependent logit transformation that provably achieves these properties, and integrates it with p-RoPE. Empirically, scale-invariant attention improves long-context language modeling and zero-shot generalization to longer contexts, and it maintains robust long-context retrieval on needle-in-a-haystack tasks, outperforming several baselines. Limitations include evaluation at relatively small scale (162M and 304M parameters) and reliance on the Gaussian-logit assumption; extension to larger models and broader attention variants remains to be tested. The results suggest a practical path to better long-context processing that does not rely on retrieval or windowing mechanisms alone.

Abstract

One persistent challenge in LLM research is the development of attention mechanisms that are able to generalise from training on shorter contexts to inference on longer contexts. We propose two conditions that we expect all effective long context attention mechanisms to have: scale-invariant total attention, and scale-invariant attention sparsity. Under a Gaussian assumption, we show that a simple position-dependent transformation of the attention logits is sufficient for these conditions to hold. Experimentally we find that the resulting scale-invariant attention scheme gives considerable benefits in terms of validation loss when zero-shot generalising from training on short contexts to validation on longer contexts, and is effective at long-context retrieval.

Paper Structure

This paper contains 28 sections, 6 theorems, 47 equations, 14 figures, 4 tables.

Key Result

lemma 1

Consider a set of random variables, $\{L_t\}$, representing attention logits. Let $\tau>0$ be a lengthscale parameter, and $\alpha>0$ a multiplicative constant. If the attention logits satisfy, then we have scale-invariant total attention (Definition 3.1).

Figures (14)

  • Figure 1: Scale-invariant attention controls the entropy without sacrificing attention over the local context. We consider three metrics for attention schemes: (left) the global attention entropy, (middle) entropy within particular ranges of tokens (e.g. 10--100), and (right) total attention to the previous 100 tokens. The top row uses IID Gaussian logits, following our theoretical approach in Sec. \ref{sec:logit_characteristics}. For LogN, the IID logits are multiplied by $s\log N$, where $N$ is the sequence length and $s=0.4$. The bottom row uses attention logits sampled from models trained with $p$-RoPE and 'No scale', LogN, and our scale-invariant transform. With no logit scaling, the attention becomes increasingly diffuse as the context grows (i.e. the attention distribution has high entropy). LogN scaling reduces the entropy and thus ensures that attention remains sparse even at longer contexts. However, LogN still forfeits the ability to attend to tokens in the local context (e.g. the 100 most recent). Scale-invariant attention strikes a balance between low entropy and attending over the local context.
  • Figure 2: Expected entropy of scale-invariant attention at different scales is sub-logarithmic. Here, we sample sequences of independent standard Gaussian logits, and apply the scale-invariant attention transformation. We estimate the expected entropy in ranges $[t, t\Delta)$, where the size of the range is controlled by $t$ (x-axis) and $\Delta$ (line color). This expected entropy measure scales sub-logarithmically (left), and the right plot suggests a $\sim \sqrt{\log(t)}$ scaling. The dashed lines show a best linear fit.
  • Figure 3: Validation losses throughout training of a 162M parameter GPT-2-like model with different attention mechanisms (our scale-invariant scheme shown in black). NoPE is omitted in all but top-left and bottom-left panels to avoid excessive zooming out (due to high loss). The gray dashed line shows the baseline of 3.28. The validation loss for many methods increases in unison at steps $\sim$1500 and $\sim$3250 (see (b), left) despite aggregating over seeds; this is due to a fixed training and validation data ordering.
  • Figure 4: Validation losses throughout training for a 304M parameter model.
  • Figure 5: Validation losses at 4k (left), 16k (middle), and 64k (right) context lengths, for a GPT-2-like model trained with scale-invariant attention for varying $\tau$. The models were trained at 4k context length.
  • ...and 9 more figures
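
The IID-Gaussian setup described in the Figure 1 caption can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the authors' code: the helper names (`softmax`, `entropy`, `attention_stats`), the seed, and the sequence lengths are assumptions chosen for the example. It covers the 'No scale' and LogN baselines only (logits multiplied by $s\log N$ with $s=0.4$, as stated in the caption); the scale-invariant transform itself is not reproduced here because its formula is not given in this summary.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array of logits."""
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()

def entropy(p):
    """Shannon entropy (in nats) of a probability vector."""
    p = p[p > 0]  # drop exact zeros so that 0 * log(0) contributes nothing
    return float(-(p * np.log(p)).sum())

def attention_stats(n, scale, rng, local=100):
    """Sample n IID standard Gaussian logits, multiply by `scale`, and return
    (global attention entropy, total attention on the last `local` tokens)."""
    logits = scale * rng.standard_normal(n)
    p = softmax(logits)
    return entropy(p), float(p[-local:].sum())

rng = np.random.default_rng(0)  # seed is an arbitrary choice for this sketch
s = 0.4                         # LogN multiplier quoted in the Figure 1 caption

for n in (1_000, 10_000, 100_000):  # hypothetical sequence lengths
    h_plain, local_plain = attention_stats(n, 1.0, rng)
    h_logn, local_logn = attention_stats(n, s * np.log(n), rng)
    print(f"N={n:>7}  no-scale: H={h_plain:.2f}, local={local_plain:.4f}  "
          f"LogN: H={h_logn:.2f}, local={local_logn:.4f}")
```

With unscaled logits the printed entropy should sit close to $\log N$ and keep growing with $N$, which is the "increasingly diffuse" behaviour the caption describes, while the LogN rows should show markedly lower entropy. Individual values depend on the random draw, so a faithful reproduction of the figure would average over many samples.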

Theorems & Definitions (13)

  • Definition 3.1: Scale-invariant total attention
  • Definition 3.2: Scale-invariant unnormalised attention sparsity
  • Definition 3.3: Weak scale-invariant attention sparsity
  • Definition 3.4: Strong scale-invariant attention sparsity
  • lemma 1
  • lemma 2
  • theorem 1
  • lemma 2
  • proof
  • lemma 2
  • ...and 3 more