Transformers Provably Learn Chain-of-Thought Reasoning with Length Generalization

Yu Huang, Zixin Wen, Aarti Singh, Yuejie Chi, Yuxin Chen

TL;DR

The paper investigates whether transformers trained by gradient descent can learn chain-of-thought (CoT) reasoning with length generalization on synthetic LEGO state-tracking tasks. Using a minimal one-layer NoPE transformer with softmax attention and an FFN, it analyzes two algebraic action structures (simply transitive and symmetry group actions) and proves three main results: provable CoT learning with length generalization for simply transitive tasks, a mechanistic dependence of length generalization on the task structure via attention concentration, and recursive self-training that provably extends reasoning length for symmetry tasks. The results extend optimization guarantees from problems in $\mathsf{TC}^0$ to $\mathsf{NC}^1$-complete problems under CoT, showing that gradient-based training can discover length-generalizing CoT solutions. Empirical LEGO experiments validate the theory: they show distinct length-generalization patterns and attention-concentration mechanisms, and demonstrate self-training as a viable route to extend CoT length in harder reasoning tasks with structured distractors.

Abstract

The ability to reason lies at the core of artificial intelligence (AI), and challenging problems usually call for deeper and longer reasoning to tackle. A crucial question about AI reasoning is whether models can extrapolate learned reasoning patterns to solve harder tasks with longer chain-of-thought (CoT). In this work, we present a theoretical analysis of transformers learning on synthetic state-tracking tasks with gradient descent. We mathematically prove how the algebraic structure of state-tracking problems governs the degree of extrapolation of the learned CoT. Specifically, our theory characterizes the length generalization of transformers through the mechanism of attention concentration, linking the retrieval robustness of the attention layer to the state-tracking task structure of long-context reasoning. Moreover, for transformers with limited reasoning length, we prove that a recursive self-training scheme can progressively extend the range of solvable problem lengths. To our knowledge, we provide the first optimization guarantee that constant-depth transformers provably learn $\mathsf{NC}^1$-complete problems with CoT, significantly going beyond prior art confined in $\mathsf{TC}^0$, unless the widely held conjecture $\mathsf{TC}^0 \neq \mathsf{NC}^1$ fails. Finally, we present a broad set of experiments supporting our theoretical results, confirming the length generalization behaviors and the mechanism of attention concentration.

Paper Structure

This paper contains 145 sections, 144 theorems, 388 equations, 7 figures, 2 algorithms.

Key Result

Theorem 1.1

One-layer transformers, trained via GD, can provably learn to solve state tracking problems for simply transitive and symmetry group actions via CoT reasoning.
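
As a rough illustration of the kind of task Theorem 1.1 refers to, the sketch below tracks a state under a simply transitive action, using the cyclic group Z_n acting on itself by addition, and writes out every intermediate state as a CoT trace. The function name `cot_trace` and the integer encoding are assumptions made for illustration; this is not the paper's LEGO encoding or its transformer.

```python
from typing import List


def cot_trace(actions: List[int], y0: int, n: int) -> List[int]:
    """Return the states [y_0, y_1, ..., y_L] of a cyclic state-tracking task.

    Each action g is an element of Z_n, and one step applies
    y_{l+1} = (y_l + g_{l+1}) mod n. Hypothetical toy encoding only.
    """
    trace = [y0 % n]
    for g in actions:
        # One CoT step: read the previous answer and the next action.
        trace.append((trace[-1] + g) % n)
    return trace


if __name__ == "__main__":
    # Three actions on Z_5, starting from state 3.
    print(cot_trace([2, 4, 1], y0=3, n=5))
```

Each step needs only the previous state and the next action, which is the retrieval pattern the paper's attention analysis is concerned with. Length generalization then asks whether a model trained on short action lists still produces a correct trace for longer ones.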

Figures (7)

  • Figure 1: Empirical results of length generalization on LEGO tasks with different group actions. (a) Transformers length-generalize to solve significantly longer CoT tasks for the simply transitive (cyclic) group, as predicted by the length-generalization theorem, while generalizing poorly for symmetry group tasks. (b) When direct length generalization falls short for symmetry actions, a recursive self-training scheme that trains on the model's own longer CoT traces bootstraps the solvable problem length, as predicted by the self-training theorem. The dashed lines indicate the training length.
  • Figure 2: Attention concentration at convergence for LEGO task with length $L=5$. The heatmap places the query clause index on the $y$-axis (keys on the $x$-axis). For a task of length $L$, the LEGO sequence prior to the final answer clause has length $2L$; we focus on query positions $L+1$ to $2L$, corresponding to answer clauses $Z_{\mathsf{ans},0}$ to $Z_{\mathsf{ans},L-1}$. Two diagonal bands in the upper region indicate attention concentrating on the answer clause $Z_{\mathsf{ans},\ell}$ and the predicate clause $Z_{\mathsf{pred},\ell+1}$ when the query is $Z_{\mathsf{ans},\ell}$.
  • Figure 3: Illustration of how the model solves the LEGO task: given $Z^{2,1}$, the goal is to predict $y_2$.
  • Figure 4: Illustration of how different components of the attention matrix $\mathbf{Q}$ route attention to the appropriate locations. The query clause is $\mathbf{Z}_{\mathsf{ans},1}$ and the goal is to retrieve the correct action $g_2$ from $\mathbf{Z}_{\mathsf{pred},2}$ and the value $y_1$ from the current answer clause $\mathbf{Z}_{\mathsf{ans},1}$. $[\mathbf{Q}_{4,p}]_{s,s}$ will grow and dominate the learning dynamics for $p\in\{3,4\}$ and $s\in\tau(\mathcal{X})$. Thus, in this example, a large $[\mathbf{Q}_{4,3}]_{\tau(x_1),\tau(x_1)}$ indicates large attention to the predicate clause $\mathbf{Z}_{\mathsf{pred},2}$, and a large $[\mathbf{Q}_{4,4}]_{\tau(x_1),\tau(x_1)}$ indicates large self-attention to the answer clause $\mathbf{Z}_{\mathsf{ans},1}$.
  • Figure 5: Attention patterns of the same trained model as in Figure 2, evaluated with randomly permuted predicate-clause positions. Column (a) gives the ground-truth permutation. Columns (b) and (c) show the attention heatmaps for the simply transitive and symmetry tasks, respectively.
  • ...and 2 more figures
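
To give a feel for why attention concentration matters for longer contexts, the toy calculation below computes how much softmax attention lands on a fixed number of target keys when their logits exceed all other logits by a common margin. The two-target setup loosely mirrors the answer and predicate clauses in Figure 2, but the uniform-margin model and the name `target_attention_mass` are assumptions for illustration, not the paper's analysis.

```python
import math


def target_attention_mass(margin: float, num_keys: int, num_targets: int = 2) -> float:
    """Softmax mass on the target keys in a toy uniform-margin model.

    Target keys have logit `margin`; every other key has logit 0.
    """
    target = num_targets * math.exp(margin)
    distractors = num_keys - num_targets  # each contributes exp(0) = 1
    return target / (target + distractors)


if __name__ == "__main__":
    for num_keys in (10, 100, 1000):
        print(num_keys, round(target_attention_mass(5.0, num_keys), 4))
```

In this toy model the mass on the targets shrinks as the number of keys grows unless the margin is large enough, which is one informal way to read the link between retrieval robustness and reasoning length.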

Theorems & Definitions (264)

  • Theorem 1.1: Learning CoT, informal
  • Theorem 1.2: Length generalization, informal
  • Theorem 1.3: Recursive self-improvement, informal
  • Definition 2.1: Word problem for a group $G$
  • Theorem 2.1: Barrington (Barrington, 1986)
  • Definition 2.2: LEGO language (Zhang et al., 2022)
  • Definition 3.1: LEGO encoding
  • Definition 3.2: Tokenization and token embedding
  • Definition 3.3: Embedding of LEGO sentences
  • Definition 3.4: Smooth ReLU
  • ...and 254 more
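
For readers unfamiliar with the word problem named in Definition 2.1, the sketch below shows one common formulation: given a sequence of group elements, compute their product. Here the elements are permutations stored as tuples. The representation and the helper names `compose` and `word_product` are assumptions for illustration, and the paper's exact definition may differ in detail.

```python
from functools import reduce
from typing import Sequence, Tuple

Perm = Tuple[int, ...]  # perm[i] is the image of point i


def compose(p: Perm, q: Perm) -> Perm:
    """Return the permutation that applies p first, then q."""
    return tuple(q[p[i]] for i in range(len(p)))


def word_product(word: Sequence[Perm], n: int) -> Perm:
    """Multiply a sequence of permutations of {0, ..., n-1} left to right."""
    identity = tuple(range(n))
    return reduce(compose, word, identity)


if __name__ == "__main__":
    swap = (1, 0, 2)   # exchanges points 0 and 1
    cycle = (1, 2, 0)  # sends 0 -> 1 -> 2 -> 0
    print(word_product([swap, cycle], 3))
    print(word_product([cycle, cycle, cycle], 3) == (0, 1, 2))
```

Evaluating the product one element at a time produces exactly the sequence of intermediate states that a CoT trace would write out.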