A Formal Framework for Understanding Length Generalization in Transformers

Xinting Huang, Andy Yang, Satwik Bhattamishra, Yash Sarrof, Andreas Krebs, Hattie Zhou, Preetum Nakkiran, Michael Hahn

TL;DR

The paper examines why transformers succeed or fail at length generalization, that is, at processing sequences longer than those seen during training. It introduces an idealized inference framework and two mathematical formalisms, Limit Transformers and C-RASP, to rigorously analyze length generalization, proving that any ground-truth function expressible by a Limit Transformer with Periodic and Local positional use will generalize for sufficiently long inputs. Experimental results across algorithmic tasks and formal languages validate the theory's predictive power, showing that length generalization correlates with C-RASP expressiveness and that certain tasks are inherently non-generalizable under the proposed constraints. Overall, the work bridges empirical observations with formal guarantees, offering a principled path toward predicting and understanding length generalization in transformers.

Abstract

A major challenge for transformers is generalizing to sequences longer than those observed during training. While previous works have empirically shown that transformers can either succeed or fail at length generalization depending on the task, theoretical understanding of this phenomenon remains limited. In this work, we introduce a rigorous theoretical framework to analyze length generalization in causal transformers with learnable absolute positional encodings. In particular, we characterize those functions that are identifiable in the limit from sufficiently long inputs with absolute positional encodings under an idealized inference scheme using a norm-based regularizer. This enables us to prove the possibility of length generalization for a rich family of problems. We experimentally validate the theory as a predictor of success and failure of length generalization across a range of algorithmic and formal language tasks. Our theory not only explains a broad set of empirical observations but also opens the way to provably predicting length generalization capabilities in transformers.

Paper Structure

This paper contains 143 sections, 37 theorems, 214 equations, 11 figures, and 11 tables.

Key Result

Theorem 1

Let $f$ be the target function expressible by a single Limit Transformer at all input lengths, subject to restrictions on the use of positional information. Choose transformers $T_n$ ($n=1,2,3,\dots$) with context size $n$, where $T_n$ reproduces the behavior of $f$ up to length $\frac{n}{2}$, while …
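The inference scheme behind Theorem 1 can be illustrated with a deliberately tiny analogy. This sketch is not the paper's construction: hypotheses are ordinary Python predicates rather than Limit Transformers, and each carries a hand-assigned `complexity` score standing in for the norm-based regularizer mentioned in the abstract (all names here are hypothetical). For a context size `n`, `infer` keeps the hypotheses that reproduce the target on every string of length at most `n // 2` and returns the one with the smallest score; `generalizes` then checks agreement on all lengths up to `n`.

```python
from itertools import product
from typing import Callable, Iterator, List, NamedTuple, Optional


class Hypothesis(NamedTuple):
    name: str
    complexity: int  # hypothetical stand-in for the regularizer value
    predict: Callable[[str], bool]


def majority(s: str) -> bool:
    """Toy target f: strictly more a's than b's."""
    return s.count("a") > s.count("b")


HYPOTHESES: List[Hypothesis] = [
    # Simplest, but wrong on strings such as "ab".
    Hypothesis("first_is_a", 0, lambda s: s[:1] == "a"),
    Hypothesis("majority", 1, majority),
    # Agrees with the target below length 10, then flips; it fits any short
    # training set but carries a higher complexity score.
    Hypothesis("majority_until_10", 3,
               lambda s: majority(s) if len(s) < 10 else not majority(s)),
]


def strings_up_to(max_len: int, alphabet: str = "ab") -> Iterator[str]:
    """All strings over the alphabet with length 0..max_len."""
    for length in range(max_len + 1):
        for chars in product(alphabet, repeat=length):
            yield "".join(chars)


def infer(n: int, target: Callable[[str], bool]) -> Optional[Hypothesis]:
    """Lowest-complexity hypothesis matching the target on every string of
    length <= n // 2 (the 'training' data for context size n)."""
    data = [(s, target(s)) for s in strings_up_to(n // 2)]
    consistent = [h for h in HYPOTHESES
                  if all(h.predict(s) == y for s, y in data)]
    return min(consistent, key=lambda h: h.complexity, default=None)


def generalizes(h: Hypothesis, n: int, target: Callable[[str], bool]) -> bool:
    """Does h agree with the target on all strings up to length n?"""
    return all(h.predict(s) == target(s) for s in strings_up_to(n))


for n in (2, 4, 8, 12):
    h = infer(n, majority)
    print(n, h.name, generalizes(h, n, majority))
# 2 first_is_a False
# 4 majority True
# 8 majority True
# 12 majority True
```

With `n = 2`, the training strings are too short to rule out the simpler but wrong `first_is_a`; from `n = 4` on, the string `ab` eliminates it, and the score prefers the target over the equally consistent but more complex `majority_until_10`. In miniature, this mirrors the "sufficiently long inputs" qualifier in the TL;DR.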

Figures (11)

  • Figure 1: Experimental results (y-axis: accuracy) at lengths $\leq 50$ (Bin 1, training), $[51, 100]$ (Bin 2), and $[101, 150]$ (Bin 3, generalization), for APE (solid) and NoPE (dotted). Green lines indicate that we found a C-RASP program ($\textbf{C-RASP}[\text{periodic},\text{local}]$ for APE, $\textbf{C-RASP}[\emptyset]$ for NoPE); red lines indicate that we proved nonexistence or found no program. Random baselines are shown in gray on the left and are very close to zero on the right. On the algorithmic problems (left), we replicate prior empirical findings, and C-RASP expressiveness predicts the observed length generalization. On the regular languages (right, with the same $x$- and $y$-axes as the left; see Table \ref{tab:finite-state-expressiveness}), length generalization tracks the $\textbf{C-RASP}$ expressiveness established in Lemma \ref{lemma:regular-languages-crasp} ((1) = $(aa)^*$, (17) = $\Sigma^*be^*$) and in other results (see Appendix \ref{sec:language-definitions}). C-RASP expressiveness predicts length generalization much better than circuit complexity and standard notions of regular-language complexity (Appendix, Figures \ref{fig:figure-ac0}--\ref{fig:star-free-ness}).
  • Figure 2: Detailed results for the regular languages, corresponding to the right part of Figure \ref{fig:results} but with each language labeled by name.
  • Figure 3: Membership in the circuit complexity class $\textbf{AC}^0$ does not predict transformers' length generalization on algorithmic problems (top) or regular languages (bottom). Prior work has often linked the expressiveness of transformers to circuit complexity (hahn2020theoretical; hao2022formal; merrill2023parallelism; strobl2023averagehard; barcelo2024logical). All tasks included in our experiments are in the class $\textbf{TC}^0$, the tightest known upper bound on transformers' expressiveness. A well-known circuit complexity class within $\textbf{TC}^0$ is $\textbf{AC}^0$, which is known to upper-bound the power of certain hard-attention models of transformers (hao2022formal; barcelo2024logical) and may therefore raise hopes that it explains transformers' practical abilities. However, membership in this class does not predict transformers' length generalization behavior. On the algorithmic problems, there is no apparent correlation at all: majority-type problems, which the attention mechanism can easily implement, are not in $\textbf{AC}^0$, whereas problems with super-logarithmic communication complexity, such as copying and addition (Corollary \ref{corr:comm-comp-no-gen}), are. On the regular languages, $\textbf{AC}^0$ exactly covers the class $\textbf{FO}[Reg]$. This class provably includes all regular languages expressible in C-RASP, but it also includes various languages on which transformers length-generalize poorly, such as Tomita-3. A natural subclass, obtained by restricting $\textbf{AC}^0$ circuits to a linear number of wires, yields the class $\textbf{FO}_2[Reg]$ (CadilhacP22), which does not match transformers' behavior well either: for example, it includes $\{0,1,2\}^*02^*$ (bottom right; equal to $\Sigma^*be^*$ from Lemma \ref{lemma:regular-languages-crasp}) but does not include D-12. Taken together, established circuit complexity classes do not account for transformers' length generalization behavior. Compare to the C-RASP results in Figures \ref{fig:results} and \ref{fig:detailed-formal-languages}.
  • Figure 4: (1) Comparing length generalization with a standard notion of complexity for finite-state languages. Star-free languages (green) do not require modular counting (mcnaughton1971counter), have simpler algebraic representations in terms of group-free monoids (schutzenberger1965finite), are easily represented by modern state-space models (Sarrof2024SSMs), and match the expressiveness of a formal model of hard-attention transformers (yang2023masked). However, star-freeness does not consistently lead to length generalization in transformers, which, conversely, length-generalize on some non-star-free languages such as $(aa)^*$. The expressiveness of $\textbf{C-RASP}$ correctly accounts for the observed behavior. (2) Within the star-free languages, a standard complexity metric is dot-depth, with higher dot-depth indicating greater complexity (non-star-free languages are plotted in gray). Dot-depth does not predict length generalization either: it succeeds on some languages at dot-depths 1 and 12 and fails on some languages at intermediate depths. See Figure \ref{fig:figure-ac0} for a discussion of another existing notion of complexity, circuit complexity, which is also much less successful than $\textbf{C-RASP}$ expressiveness at predicting length generalization. Compare to the C-RASP results in Figures \ref{fig:results} and \ref{fig:detailed-formal-languages}.
  • Figure 5: (Appendix \ref{sec:empirical-pos-gen}) MSE loss when fitting (length $= 50$) and generalizing (greater lengths) functions $\phi(\cdot,\cdot)$ with products ${\bm{p}}_j^T {\bm{K}}^T {\bm{Q}} {\bm{p}}_i$. We show local functions testing whether $j = i-c$ (left) or $j > i-c$ (center), and periodic functions testing whether $i-j \equiv c_2 \pmod{c_1}$ (right), each at low (top, $d=32$) and high (bottom, $d=256$) dimensionality. Local functions length-generalize well when dimensionality is high (bottom left, bottom center); generalization is more successful for functions concentrated on few pairs (bottom left is nonzero at only one value of $j-i$; bottom center is nonzero at $c$ different values of $j-i$). Periodic functions length-generalize well when dimensionality is low (top right). These results match the distinct roles that local and periodic functions play in our theoretical constructions: periodic functions are mediated by bounded-rank products (Lemma \ref{lemma:deducing-periodicity-fidi}), while local functions are mediated by the products ${\bm{p}}_j^T {\bm{K}}_{l,h}^T {\bm{Q}}_{l,h} {\bm{p}}_i$.
  • ...and 6 more figures
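One way to see why Figure 5 separates local from periodic functions is to write each function $\phi(i, j)$ as an $n \times n$ matrix over position pairs and look at its rank, since a product ${\bm{p}}_j^T {\bm{K}}^T {\bm{Q}} {\bm{p}}_i$ with $d$-dimensional encodings can only produce matrices of rank at most $d$. The sketch below (plain Python with exact rational arithmetic; the helper names and the constants `C`, `C1`, `C2` are illustrative choices, not values from the paper) shows that the periodic indicator has rank $c_1$ at every length, whereas the local indicator $j = i - c$ has rank $n - c$, which grows with the length (restricting attention to causal pairs $j \le i$ does not change either conclusion). This is a linear-algebra observation consistent with the caption's account, not a reproduction of the paper's experiment.

```python
from fractions import Fraction
from typing import Callable, List

Matrix = List[List[int]]


def matrix_rank(rows: Matrix) -> int:
    """Exact rank by Gaussian elimination over the rationals."""
    m = [[Fraction(x) for x in row] for row in rows]
    rank = 0
    n_cols = len(m[0]) if m else 0
    for col in range(n_cols):
        # Look for a nonzero pivot in this column among the unused rows.
        pivot = next((r for r in range(rank, len(m)) if m[r][col] != 0), None)
        if pivot is None:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        # Eliminate this column from every other row.
        for r in range(len(m)):
            if r != rank and m[r][col] != 0:
                factor = m[r][col] / m[rank][col]
                m[r] = [a - factor * b for a, b in zip(m[r], m[rank])]
        rank += 1
    return rank


def position_matrix(n: int, phi: Callable[[int, int], bool]) -> Matrix:
    """phi(i, j) as an n x n 0/1 matrix; rows are queries i, columns keys j (1-indexed)."""
    return [[int(phi(i, j)) for j in range(1, n + 1)] for i in range(1, n + 1)]


C, C1, C2 = 2, 3, 1  # illustrative constants, not values from the paper


def local(i: int, j: int) -> bool:
    """Local function: key j sits exactly C positions before query i."""
    return j == i - C


def periodic(i: int, j: int) -> bool:
    """Periodic function: i - j is congruent to C2 modulo C1."""
    return (i - j) % C1 == C2


for n in (10, 20, 40):
    print(n, matrix_rank(position_matrix(n, local)),
          matrix_rank(position_matrix(n, periodic)))
# 10 8 3
# 20 18 3
# 40 38 3
```

The periodic matrix has only $c_1$ distinct rows (row $i$ depends only on $i \bmod c_1$), so one fixed low-dimensional construction serves every length. The local matrix has $n - c$ ones in distinct rows and columns, so an exact fit needs a dimension that grows with $n$.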

Theorems & Definitions (87)

  • Theorem 1: Informal Version of Theorem \ref{thm:guarantee}
  • Definition 2
  • Definition 3
  • Definition 4: Hypothesis Class
  • Definition 5: Regularizer
  • Definition 6: Inference Procedure
  • Theorem 7: Guaranteed Length Generalization in the Limit
  • Definition 8: $\textbf{C-RASP}$
  • Theorem 9
  • Theorem 10
  • ...and 77 more
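Definition 8 and Figure 1 refer to C-RASP programs. As a rough illustration only, the sketch below evaluates a small fragment modeled loosely on C-RASP's operations: per-position Boolean vectors, causal counts $\#[j \le i]\,P(j)$, and comparisons between counts. The function names, the optional positional relation `rel(i, j)` used to model the periodic extension, and the choice of alphabet $\{a, b\}$ for $(aa)^*$ are our assumptions and may not match Definition 8. `majority` (a majority-type problem) uses no positional information, while `even_a_star` uses the hypothetical periodic relation $(i - j) \bmod 2$ to test whether the current position is even, by checking that as many positions $j \le i$ share the parity of $i$ as do not.

```python
from typing import Callable, List, Optional

BoolVec = List[bool]   # one Boolean per position i = 1..n
CountVec = List[int]   # one count per position i = 1..n
Relation = Callable[[int, int], bool]


def symbol(w: str, a: str) -> BoolVec:
    """Q_a(i): the token at position i is a."""
    return [ch == a for ch in w]


def neg(p: BoolVec) -> BoolVec:
    return [not x for x in p]


def conj(p: BoolVec, q: BoolVec) -> BoolVec:
    return [x and y for x, y in zip(p, q)]


def count(p: BoolVec, rel: Optional[Relation] = None) -> CountVec:
    """#[j <= i, rel(i, j)] P(j): number of positions j <= i where P holds
    (and, if given, the positional relation rel(i, j) holds)."""
    return [
        sum(1 for j in range(1, i + 1)
            if p[j - 1] and (rel is None or rel(i, j)))
        for i in range(1, len(p) + 1)
    ]


def const(k: int, n: int) -> CountVec:
    return [k] * n


def less(c: CountVec, d: CountVec) -> BoolVec:
    return [x < y for x, y in zip(c, d)]


def equal(c: CountVec, d: CountVec) -> BoolVec:
    # Equality built from '<' and Boolean operations only.
    return conj(neg(less(c, d)), neg(less(d, c)))


def majority(w: str) -> BoolVec:
    """More a's than b's in the prefix ending at i; no positional information."""
    return less(count(symbol(w, "b")), count(symbol(w, "a")))


def even_a_star(w: str) -> BoolVec:
    """(aa)^* over the assumed alphabet {a, b}: no b so far and i is even."""
    n = len(w)
    # Constant-true vector (expressible with Boolean operations as NOT(Q_a AND NOT Q_a)).
    everywhere = [True] * n
    same_parity = count(everywhere, lambda i, j: (i - j) % 2 == 0)
    other_parity = count(everywhere, lambda i, j: (i - j) % 2 == 1)
    no_b = equal(count(symbol(w, "b")), const(0, n))
    return conj(no_b, equal(same_parity, other_parity))


def accepts(program: Callable[[str], BoolVec], w: str) -> bool:
    """Read the program's output at the last position (w must be non-empty)."""
    if not w:
        raise ValueError("this sketch only evaluates non-empty strings")
    return program(w)[-1]


print(accepts(majority, "aab"), accepts(majority, "abb"))  # True False
print(accepts(even_a_star, "aaaa"), accepts(even_a_star, "aaa"),
      accepts(even_a_star, "aaba"))  # True False False
```

For an even position $i$, exactly $i/2$ positions $j \le i$ have $i - j$ even and $i/2$ have it odd; for an odd $i$ the two counts differ by one, so comparing them detects parity without any count modulo 2.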