The Role of Sparsity for Length Generalization in Transformers

Noah Golowich, Samy Jelassi, David Brandfonbrener, Sham M. Kakade, Eran Malach

TL;DR

The paper presents a principled theory of length generalization in decoder-only transformers. It shows that if each predicted token depends on a fixed number $k$ of past tokens (i.e., the data follow a $k$-sparse planted correlation distribution), then length generalization to longer contexts is provable for a sparse functional attention class under a locality assumption. It further shows that position coupling can remove the strict locality requirement, and introduces Predictive Position Coupling (PPC) to handle coupled positions that depend on the input. The authors validate the theory with synthetic parity tasks and natural-language experiments, showing that length generalization improves as the sparsity parameter $k$ decreases (i.e., as each token depends on fewer previous tokens) and that PPC expands the range of tasks for which position coupling is effective. Relative to prior provable length-generalization results, the work allows nonlinear attention heads and relies on distributional assumptions; it also gives a theoretical justification for PoSE and introduces PPC, two techniques that can improve long-context generalization in practice. Overall, the results provide a theoretical and empirical blueprint for designing transformers whose length generalization rests on sparse dependencies in the data and on structured positional information.
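
The TL;DR mentions synthetic parity tasks. As a minimal sketch, assuming a simplified encoding in which each position carries one random bit plus a flag marking whether it is relevant, the generator below produces examples whose label depends on exactly $k$ of the $n$ positions no matter how long the sequence is, which is the kind of sparse dependency the theory targets. The function name and encoding are hypothetical; the paper's actual sparse parity format may differ.

```python
import random
from typing import List, Tuple


def sample_sparse_parity(n: int, k: int, rng: random.Random) -> Tuple[List[int], List[int], int]:
    """Sample one example of a hypothetical k-sparse parity task of length n.

    Returns (bits, marks, label): bits[i] is a random bit, marks[i] == 1 for the
    k positions the label depends on, and label is the XOR of those k bits.
    """
    if not 0 < k <= n:
        raise ValueError("need 0 < k <= n")
    bits = [rng.randint(0, 1) for _ in range(n)]
    marked = set(rng.sample(range(n), k))  # the k relevant positions
    marks = [1 if i in marked else 0 for i in range(n)]
    label = 0
    for i in marked:
        label ^= bits[i]
    return bits, marks, label


rng = random.Random(0)
# Train on short sequences, evaluate on longer ones; k stays fixed.
train_set = [sample_sparse_parity(n=20, k=4, rng=rng) for _ in range(1000)]
test_set = [sample_sparse_parity(n=80, k=4, rng=rng) for _ in range(1000)]
```

Keeping $k$ fixed while growing $n$ from training to test length is what makes this a length-generalization probe rather than a harder task.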

Abstract

Training large language models to predict beyond their training context lengths has drawn much attention in recent years, yet the principles driving such length generalization behavior remain underexplored. We propose a new theoretical framework to study length generalization for the next-token prediction task, as performed by decoder-only transformers. Conceptually, we show that length generalization occurs as long as each predicted token depends on a small (fixed) number of previous tokens. We formalize such tasks via a notion we call $k$-sparse planted correlation distributions, and show that an idealized model of transformers, which generalizes attention heads, successfully length-generalizes on such tasks. As a bonus, our theoretical model justifies certain techniques to modify positional embeddings which have been introduced to improve length generalization, such as position coupling. We support our theoretical results with experiments on synthetic tasks and natural language, which confirm that a key factor driving length generalization is a "sparse" dependency structure of each token on the previous ones. Inspired by our theory, we introduce Predictive Position Coupling, which trains the transformer to predict the position IDs used in a position coupling approach. Predictive Position Coupling thereby allows us to broaden the array of tasks to which position coupling can successfully be applied to achieve length generalization.
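
The abstract describes Predictive Position Coupling only at a high level. The sketch below is a hypothetical data-side illustration, assuming a variable-assignment task with a scratchpad (in the spirit of the Figure 3 experiments) in which each scratchpad token is coupled to the position ID of the input token it copies. Because those coupled IDs depend on the input's contents, they cannot be computed from sequence length alone, which is why a PPC-style model would be trained to predict the next position ID alongside the next token. The task format and the helpers `build_example` and `ppc_targets` are illustrative assumptions, not the paper's implementation.

```python
from typing import Dict, List, Tuple


def build_example(program: List[Tuple[str, str]], query: str) -> Tuple[List[str], List[int]]:
    """Hypothetical variable-assignment example with coupled position IDs.

    program is a list of (lhs, rhs) assignments, where rhs is a variable name
    or a constant. Input tokens get ordinary positions 1, 2, 3, ...; each
    scratchpad token reuses the position ID of the input token it copies.
    """
    tokens: List[str] = []
    ids: List[int] = []
    rhs_of: Dict[str, str] = {}
    rhs_pos: Dict[str, int] = {}  # lhs variable -> position ID of its rhs token
    for lhs, rhs in program:
        tokens += [lhs, rhs]
        ids += [len(ids) + 1, len(ids) + 2]
        rhs_of[lhs], rhs_pos[lhs] = rhs, ids[-1]
    tokens += ["?", query]
    ids += [len(ids) + 1, len(ids) + 2]
    var, seen = query, set()
    while var in rhs_of:  # follow the chain until reaching a constant
        if var in seen:
            raise ValueError("cyclic program")
        seen.add(var)
        tokens.append(rhs_of[var])
        ids.append(rhs_pos[var])  # coupled ID: depends on where the assignment sits
        var = rhs_of[var]
    return tokens, ids


def ppc_targets(tokens: List[str], ids: List[int]) -> List[Tuple[str, int]]:
    """For each prefix, the (next token, next position ID) pair to be predicted."""
    return [(tokens[t + 1], ids[t + 1]) for t in range(len(tokens) - 1)]


tokens, ids = build_example([("b", "a"), ("a", "7"), ("c", "b")], query="c")
print(tokens)  # ['b', 'a', 'a', '7', 'c', 'b', '?', 'c', 'b', 'a', '7']
print(ids)     # [1, 2, 3, 4, 5, 6, 7, 8, 6, 2, 4]
```

The scratchpad IDs (6, 2, 4) jump back to the assignments being read, so reordering the program changes them even though the sequence length stays the same.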

Paper Structure

This paper contains 72 sections, 9 theorems, 52 equations, 8 figures, 5 tables.

Key Result

Proposition 3.1

Fix any $L, \bar{L} \in \mathbb N$ with $L < \bar{L}$, and a hypothesis class $\mathcal{H} \subset \mathcal{Y}^{\mathcal{V}^\star}$. Let $\mathcal{P}_1, \ldots, \mathcal{P}_{L}$ be $\mathcal{H}$-realizable distributions, realized by $h^\star$. Suppose that $\epsilon > 0$ and $\hat{h}$ defined in eq:

Figures (8)

  • Figure 1: Length generalization for the sparse parity task with $K_{\mathsf{train}} \in \{4,6,8,10,12\}$ for various values of $k_{\mathsf{test}}$.
  • Figure 2: Parity with scratchpad and Predictive Position Coupling
  • Figure 3: Length generalization for the variable assignment experiments (section label sec:parity-scratchpad). Panels pointer-cot-absshift, pointer-rope-pose, and pointer-cot-rope-pose show length generalization behavior with modifications that replace Predictive Position Coupling; length generalization is significantly worse than in panel pointer-cot. Panels pointer-cot-absshift and pointer-cot-rope-pose report full-string accuracy on the scratchpad; panel pointer-rope-pose (which does not use a scratchpad) reports the model's accuracy at predicting the answer token. Moreover, consistent with the sparsity takeaway (it:ta-sparsity), it appears that, even when restricting to RoPE with PoSE, having a scratchpad is superior to not having one (even when accuracy in the scratchpad case is measured by full-string correctness, as in panel pointer-cot-rope-pose), though there is significant variance between training runs.
  • Figure 4: Cross-entropy loss differences obtained by unmasking $k$ influential tokens ($x$-axis) or all of the first 64 tokens ($y$-axis).
  • Figure 5: Illustration of the computation of $\mathcal{L}_{\mathsf{short}}, \mathcal{L}_{\mathsf{long}}, \mathcal{L}_{\mathsf{short,sparse}}$ with $\bar{L} = 8, L = 2$, $k = 2$, $J_1 = 2, J_2 = 4$. Tokens shaded gray are masked (i.e., not attended to) while those shaded blue are not masked (i.e., are attended to).
  • ...and 3 more figures
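
The Figure 5 caption describes three losses computed under different masks. One plausible reading, assumed here rather than taken from the paper's definitions, is that $\mathcal{L}_{\mathsf{long}}$ attends to all $\bar{L}$ context tokens, $\mathcal{L}_{\mathsf{short}}$ attends only to the last $L$ of them, and $\mathcal{L}_{\mathsf{short,sparse}}$ additionally unmasks the $k$ tokens at positions $J_1, \ldots, J_k$. The hypothetical helper `context_mask` builds those visibility patterns for the caption's parameters ($\bar{L} = 8$, $L = 2$, $k = 2$, $J_1 = 2$, $J_2 = 4$).

```python
from typing import Iterable, List


def context_mask(L_bar: int, L: int, unmask: Iterable[int] = ()) -> List[bool]:
    """Visibility of positions 1..L_bar (True = attended to).

    Assumed reading: the short context is the last L positions, and `unmask`
    lists extra 1-indexed positions (J_1, ..., J_k) that are made visible.
    """
    if not 0 < L <= L_bar:
        raise ValueError("need 0 < L <= L_bar")
    visible = [p > L_bar - L for p in range(1, L_bar + 1)]
    for j in unmask:
        if not 1 <= j <= L_bar:
            raise ValueError("unmasked position out of range")
        visible[j - 1] = True
    return visible


L_bar, L = 8, 2
short = context_mask(L_bar, L)                  # mask for L_short
sparse = context_mask(L_bar, L, unmask=(2, 4))  # mask for L_short,sparse (k = 2)
full = [True] * L_bar                           # mask for L_long
```

Under this reading, the two Figure 4 axes would correspond to the loss gaps $\mathcal{L}_{\mathsf{short}} - \mathcal{L}_{\mathsf{short,sparse}}$ and $\mathcal{L}_{\mathsf{short}} - \mathcal{L}_{\mathsf{long}}$.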

Theorems & Definitions (23)

  • Definition 3.1: Length generalization
  • Proposition 3.1
  • Definition 3.2: $k$-sparse planted correlations
  • Definition 3.3: Sparse functional attention class
  • Proposition 3.2: Informal version of the corresponding formal proposition (prop:model-attn-head-formal)
  • Theorem 4.3: Provable length generalization
  • Definition 4.1: Local position coupling
  • Proposition 4.4: Informal version of the corresponding formal proposition (prop:pc-length-extrap)
  • Remark 4.2: Theoretical justification for PoSE
  • Remark 5.1
  • ...and 13 more
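
Remark 4.2 gives a theoretical justification for PoSE (positional skip-wise training), a technique from prior work that exposes the model to large position IDs while it trains on short windows. A minimal two-chunk sketch of the position-ID sampling is below; the original method can use more chunks and includes further details, and the function name and parameters here are illustrative assumptions rather than the paper's setup.

```python
import random
from typing import List


def pose_position_ids(L: int, L_target: int, rng: random.Random) -> List[int]:
    """Two-chunk PoSE-style position IDs for a training window of L tokens.

    The window is split into two chunks; the second chunk's IDs are shifted by
    a random skip so that IDs up to L_target - 1 can appear while only L tokens
    are processed.
    """
    if not 2 <= L <= L_target:
        raise ValueError("need 2 <= L <= L_target")
    split = rng.randint(1, L - 1)        # chunk 1 = tokens [0, split)
    skip = rng.randint(0, L_target - L)  # skip bias added to chunk 2
    return list(range(split)) + list(range(split + skip, L + skip))


rng = random.Random(0)
ids = pose_position_ids(L=16, L_target=64, rng=rng)
```

Within each chunk the relative distances are unchanged, while the jump between chunks lets distances up to $L_{\text{target}} - 1$ occur during short-context training.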