
On Vanishing Variance in Transformer Length Generalization

Ruining Li, Gabrijel Boduljak, Jensen Zhou

TL;DR

This paper introduces a vanishing-variance perspective on Transformer length generalization. It shows that as sequence length $N$ increases, the variance of attention outputs decays, inducing a distribution shift that harms generalization to longer inputs. It combines theoretical reasoning with empirical studies on order-invariant tasks, revealing that applying LayerNorm after the attention outputs stabilizes their global statistics and improves out-of-distribution performance, though it does not fully eliminate the decay. Ablation studies indicate that normalization, particularly LayerNorm, significantly mitigates length-related degradation, while plain standardization also helps, albeit with less capacity than LayerNorm. The work suggests a path toward more robust, length-invariant architectures and highlights the need for architectural design that goes beyond ad-hoc positional encodings to ensure reliable long-sequence reasoning in Transformers.
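The vanishing-variance effect can be reproduced in a toy setting. The sketch below is not the authors' code: it assumes a single softmax attention head whose fixed random weights stand in for trained ones, a single fixed query vector, and i.i.d. Gaussian tokens. It measures the standard deviation of one output component across many sequences of the same length $N$, loosely mirroring the measurement in Figure 1. The function name `attention_output_std` and all sizes are hypothetical.

```python
import numpy as np

def attention_output_std(n_tokens, n_seqs=500, dim=16, component=0, seed=0):
    """Std, across n_seqs random sequences, of one output component of a single
    softmax attention head with fixed random weights and a fixed query."""
    rng = np.random.default_rng(seed)
    # Hypothetical stand-ins for trained parameters. They are drawn first from the
    # same seed, so every n_tokens value is evaluated with identical weights and query.
    w_k = rng.normal(0.0, dim ** -0.5, (dim, dim))
    w_v = rng.normal(0.0, dim ** -0.5, (dim, dim))
    q = rng.normal(size=dim)
    samples = np.empty(n_seqs)
    for s in range(n_seqs):
        x = rng.normal(size=(n_tokens, dim))    # i.i.d. Gaussian tokens (assumption)
        scores = (x @ w_k) @ q / np.sqrt(dim)   # one score per key
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()                # softmax over the N keys
        samples[s] = (weights @ (x @ w_v))[component]
    return samples.std()

lengths = [16, 256, 4096]
stds = [attention_output_std(n) for n in lengths]
# Slope of log(std) against log(N), analogous to reading a log-log plot.
slope = np.polyfit(np.log(lengths), np.log(stds), 1)[0]
print([round(s, 4) for s in stds], round(slope, 2))
```

Under these assumptions the output variance tracks $\sum_n a_n^2$, which shrinks roughly like $1/N$ when the attention scores stay bounded, so the fitted slope should come out near $-0.5$. A trained model need not follow this exact rate.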

Abstract

It is widely known that Transformers trained on shorter sequences fail to generalize robustly to longer ones at test time. This raises the question of whether Transformer models are real reasoning engines, despite their impressive abilities in mathematical problem solving and code synthesis. In this paper, we offer a vanishing-variance perspective on this issue. To the best of our knowledge, we are the first to demonstrate that even for today's frontier models, a longer sequence length results in a decrease in the variance of the output of the multi-head attention modules. On the argmax retrieval and dictionary lookup tasks, our experiments show that applying layer normalization after the attention outputs leads to significantly better length generalization. Our analyses attribute this improvement to a reduction, though not a complete elimination, of the distribution shift caused by vanishing variance.
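To make "layer normalization after the attention outputs" concrete, here is a minimal NumPy sketch of a single-head attention sublayer with an optional LayerNorm on the attention output $\mathbf{O}$. It illustrates the idea under assumptions and is not the paper's architecture. The single head, the absence of masking and positional encodings, and applying LayerNorm before the residual addition are choices made for this sketch. The parameter names (`w_q`, `w_k`, `w_v`, `w_o`, `gamma`, `beta`) and the function `attention_sublayer` are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(h, gamma, beta, eps=1e-5):
    # Standardize each token's feature vector, then apply the learned scale and shift.
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return gamma * (h - mu) / np.sqrt(var + eps) + beta

def attention_sublayer(x, p, post_attn_ln=True):
    """x: (N, D) token features; p: dict of hypothetical parameters. Returns (N, D)."""
    q, k, v = x @ p["w_q"], x @ p["w_k"], x @ p["w_v"]
    a = softmax(q @ k.T / np.sqrt(x.shape[-1]))   # (N, N) weights, no mask, no positions
    o = (a @ v) @ p["w_o"]                        # attention output O
    if post_attn_ln:
        o = layer_norm(o, p["gamma"], p["beta"])  # LayerNorm on O (assumed before the residual)
    return x + o                                  # residual connection

rng = np.random.default_rng(0)
D = 32
p = {name: rng.normal(0.0, D ** -0.5, (D, D)) for name in ("w_q", "w_k", "w_v", "w_o")}
p["gamma"], p["beta"] = np.ones(D), np.zeros(D)

stats = {}
for n in (8, 1024):
    x = rng.normal(size=(n, D))
    raw = attention_sublayer(x, p, post_attn_ln=False) - x  # recover O
    normed = attention_sublayer(x, p, post_attn_ln=True) - x
    # Mean over tokens of each token's variance across the D features.
    stats[n] = (raw.var(axis=-1).mean(), normed.var(axis=-1).mean())
print(stats)
```

In this toy run the per-token variance of the raw $\mathbf{O}$ shrinks as $N$ grows, while the LayerNorm branch stays close to 1 by construction. This only shows what the normalization enforces; it does not by itself explain the improved length generalization.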

Paper Structure

This paper contains 19 sections, 2 theorems, 5 equations, 4 figures, 6 tables.

Key Result

Proposition 1

Consider a trained attention module with weights $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V, \mathbf{W}_O$. Let $\mathbf{X} = \left[\mathbf{x}_1 \Vert \mathbf{x}_2 \Vert \dots \Vert \mathbf{x}_N \right]^{\top}$ denote an input sequence of length $N$. If (1) $\mathbf{x}_1, \mathbf{x}_2, \dots$ […] where $\mathbf{x}_n, \mathbf{y} \in \mathbb{R}^D$ and $\mathbf{Q} \in \mathbb{R}^{1\times D}$, […]
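The statement above is cut off in this extraction, so what follows is only an intuition for why the variance vanishes. It rests on a simplifying assumption that is stronger than, and may differ from, the proposition's own conditions: the value vectors $\mathbf{v}_n^{\top} = \mathbf{x}_n^{\top}\mathbf{W}_V$ are i.i.d. with per-component mean $\mu_j$ and variance $\sigma_j^2$, independent of the attention weights $a_n = e^{s_n} / \sum_m e^{s_m}$, which sum to one.

```latex
\begin{aligned}
y_j &= \sum_{n=1}^{N} a_n v_{n,j}, \qquad \sum_{n=1}^{N} a_n = 1, \\
\mathbb{E}[y_j \mid \mathbf{a}] &= \mu_j \sum_{n=1}^{N} a_n = \mu_j, \qquad
\operatorname{Var}(y_j \mid \mathbf{a}) = \sigma_j^2 \sum_{n=1}^{N} a_n^2, \\
\operatorname{Var}(y_j) &= \mathbb{E}\big[\operatorname{Var}(y_j \mid \mathbf{a})\big]
  + \operatorname{Var}\big(\mathbb{E}[y_j \mid \mathbf{a}]\big)
  = \sigma_j^2 \, \mathbb{E}\Big[\sum_{n=1}^{N} a_n^2\Big], \\
\frac{1}{N} &\le \sum_{n=1}^{N} a_n^2 \le \max_n a_n \le \frac{e^{2B}}{N}
  \quad \text{if every score satisfies } |s_n| \le B .
\end{aligned}
```

The lower bound is Cauchy-Schwarz and is attained by uniform attention. The upper bound uses $a_n \le e^{B} / (N e^{-B})$. With bounded scores, therefore, $\operatorname{Var}(y_j) \le \sigma_j^2 e^{2B} / N \to 0$ as $N \to \infty$, and any fixed linear map such as $\mathbf{W}_O$ preserves this $O(1/N)$ rate.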

Figures (4)

  • Figure 1: Standard deviation of a fixed component of the attention outputs from the first layer of Llama-3.2-1B (log-log scale), computed over multiple input sequences of fixed length $N$. Even in the latest LLMs, increasing sequence length $N$ reduces the variance of the attention outputs, significantly degrading accuracy on long sequences.
  • Figure 2: Distribution of 5 individual features in attention outputs $\mathbf{O}$ across batches. Each color represents a different feature. As input sequence length $N$ increases from $2^4$ to $2^{14}$, feature variance decreases, and values concentrate around their mean. Layer normalization (bottom) scales and shifts features to maintain relatively constant global variance, likely explaining its superior length generalization compared to the Baseline (top).
  • Figure 3: Layer normalization helps mitigate distribution shift in attention outputs. (Left) shows the drift in global mean as input sequence length deviates from the training distribution. The mean is normalized by the training global variance to eliminate scale differences. (Right) shows the decay in global variance. All results are averaged across $32$k random input sequences of the fixed length.
  • Figure 4: Heatmap of the largest $16$ attention weights, computed over $32$ examples. Layer normalization mitigates dispersion, which is inevitable as sequence length increases (Veličković et al., 2024).
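The dispersion effect in Figure 4 can be illustrated with a toy attention head. The sketch below averages the 16 largest softmax weights over 32 random sequences, echoing the figure's setup. Its assumptions are a fixed random key projection, a fixed query, and i.i.d. Gaussian tokens, and the function `top_k_attention_weights` is hypothetical. Because the scores stay bounded while the softmax denominator grows with $N$, every individual weight gets diluted.

```python
import numpy as np

def top_k_attention_weights(n_tokens, k=16, n_examples=32, dim=16, seed=0):
    """Mean of the k largest softmax weights (sorted descending) over n_examples sequences."""
    rng = np.random.default_rng(seed)
    w_k = rng.normal(0.0, dim ** -0.5, (dim, dim))  # hypothetical fixed key projection
    q = rng.normal(size=dim)                         # hypothetical fixed query
    total = np.zeros(k)
    for _ in range(n_examples):
        x = rng.normal(size=(n_tokens, dim))         # i.i.d. Gaussian tokens
        scores = (x @ w_k) @ q / np.sqrt(dim)
        w = np.exp(scores - scores.max())
        w /= w.sum()                                 # softmax over all N keys
        total += np.sort(w)[::-1][:k]                # k largest weights, descending
    return total / n_examples

results = {n: top_k_attention_weights(n) for n in (16, 256, 4096)}
for n, top in results.items():
    print(f"N={n:5d}  largest weight={top[0]:.4f}  top-16 mass={top.sum():.4f}")
```

At $N = 16$ the 16 weights are the whole distribution, so their mass is exactly 1. As $N$ grows, both the largest weight and the top-16 mass fall, which is the sense in which attention disperses. This toy omits the LayerNorm comparison shown in the figure, since that depends on a trained multi-layer model.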

Theorems & Definitions (4)

  • Proposition 1: The vanishing variance problem
  • proof
  • Proposition 1: The vanishing variance problem
  • proof