Transformers Provably Learn Chain-of-Thought Reasoning with Length Generalization
Yu Huang, Zixin Wen, Aarti Singh, Yuejie Chi, Yuxin Chen
TL;DR
The paper investigates whether transformers trained by gradient descent can learn chain-of-thought (CoT) reasoning that length-generalizes on synthetic LEGO state-tracking tasks. Using a minimal one-layer NoPE (no positional encoding) transformer with softmax attention and a feed-forward network (FFN), it analyzes two algebraic action structures, simply transitive group actions and symmetry group actions, and proves three main results: (i) provable CoT learning with length generalization for simply transitive tasks; (ii) a mechanistic account in which length generalization depends on the task structure through attention concentration; and (iii) a recursive self-training scheme that provably extends the reasoning length for symmetry tasks. These results give optimization-based guarantees that gradient-based training can discover length-generalizing CoT solutions, and they show that with CoT, constant-depth transformers can provably learn $\mathsf{NC}^1$-complete problems, going beyond prior guarantees confined to $\mathsf{TC}^0$. LEGO experiments validate the theory: they exhibit distinct length-generalization patterns for the two task structures, confirm the attention-concentration mechanism, and demonstrate self-training as a viable route to longer CoT on harder reasoning tasks with structured distractors.
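To make the task family concrete, here is a minimal sketch of a LEGO-style state-tracking chain solved by step-by-step CoT. It assumes a simplified format that is not the paper's exact data format: each clause reads "x_i = g_i · x_{i-1}", clauses appear in shuffled order, and the CoT resolves one variable per step by retrieving the clause whose source is the variable just computed (the retrieval the attention layer has to perform) and then applying that clause's group element. The two actions are also illustrative assumptions: the cyclic group $\mathbb{Z}_n$ acting on itself by addition (simply transitive: exactly one $g$ maps any $x$ to any $y$) and the symmetric group $S_5$ permuting $\{0,\dots,4\}$ (transitive but not simply transitive). The helper names `cyclic_action`, `symmetric_action`, `make_chain`, `solve_with_cot`, and `is_simply_transitive` are hypothetical.

```python
import random
from itertools import permutations


def cyclic_action(n):
    """Hypothetical example: Z_n acting on {0, ..., n-1} by addition mod n (simply transitive)."""
    def act(g, x):
        return (g + x) % n
    return list(range(n)), act


def symmetric_action(k):
    """Hypothetical example: S_k acting on {0, ..., k-1}; permutation g sends x to g[x]."""
    def act(g, x):
        return g[x]
    return list(permutations(range(k))), act


def is_simply_transitive(elements, act, states):
    """True iff every pair (x, y) is connected by exactly one group element."""
    return all(sum(act(g, x) == y for g in elements) == 1
               for x in states for y in states)


def make_chain(length, elements, act, states, rng):
    """Sample x_0 and clauses x_i = g_i . x_{i-1} for i = 1..length, in shuffled order.
    Each clause is a tuple (target index, source index, group element)."""
    x0 = rng.choice(states)
    gs = [rng.choice(elements) for _ in range(length)]
    clauses = [(i, i - 1, g) for i, g in enumerate(gs, start=1)]
    rng.shuffle(clauses)  # prompt order does not follow the chain
    truth = [x0]
    for g in gs:
        truth.append(act(g, truth[-1]))
    return x0, clauses, truth


def solve_with_cot(x0, clauses, act):
    """Emit one state per CoT step: retrieve the clause whose source is the
    variable just resolved, then apply its group element to the current state."""
    by_source = {src: (tgt, g) for tgt, src, g in clauses}
    trace, var = [x0], 0
    while var in by_source:
        tgt, g = by_source[var]
        trace.append(act(g, trace[-1]))
        var = tgt
    return trace


if __name__ == "__main__":
    rng = random.Random(0)
    elements, act = symmetric_action(5)
    x0, clauses, truth = make_chain(12, elements, act, list(range(5)), rng)
    print(is_simply_transitive(elements, act, list(range(5))))  # False: 24 permutations map x to y
    print(solve_with_cot(x0, clauses, act) == truth)  # True
```

In this sketch the CoT trace has one step per clause, so extrapolating to longer chains means the same retrieve-then-apply step must stay reliable among a growing number of clauses.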
Abstract
The ability to reason lies at the core of artificial intelligence (AI), and challenging problems usually require deeper and longer reasoning. A crucial question about AI reasoning is whether models can extrapolate learned reasoning patterns to solve harder tasks with longer chain-of-thought (CoT). In this work, we present a theoretical analysis of transformers trained with gradient descent on synthetic state-tracking tasks. We mathematically prove how the algebraic structure of state-tracking problems governs the degree of extrapolation of the learned CoT. Specifically, our theory characterizes the length generalization of transformers through the mechanism of attention concentration, linking the retrieval robustness of the attention layer to the structure of the state-tracking task in long-context reasoning. Moreover, for transformers with limited reasoning length, we prove that a recursive self-training scheme can progressively extend the range of solvable problem lengths. To our knowledge, we provide the first optimization guarantee that constant-depth transformers provably learn $\mathsf{NC}^1$-complete problems with CoT, significantly going beyond prior results confined to $\mathsf{TC}^0$, unless the widely held conjecture $\mathsf{TC}^0 \neq \mathsf{NC}^1$ fails. Finally, we present a broad set of experiments supporting our theoretical results, confirming the length generalization behaviors and the mechanism of attention concentration.
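Why retrieval must concentrate more sharply over longer contexts can be motivated by a generic property of softmax, sketched below under a simplifying assumption that is ours, not the paper's: a single target position (the clause to retrieve) has a logit exceeding each of the other $n-1$ logits by the same margin $m$. The target then receives weight $e^m/(e^m+n-1)$, which decays as $n$ grows unless $m$ grows like $\log n$; keeping weight $p$ on the target requires $m=\log\big(p(n-1)/(1-p)\big)$. The function names are hypothetical, and this is an illustration of the intuition, not the paper's analysis.

```python
import math


def target_attention_mass(margin, n_positions):
    """Softmax weight on one target whose logit exceeds each of the other
    n_positions - 1 logits by exactly `margin` (simplifying assumption)."""
    return math.exp(margin) / (math.exp(margin) + n_positions - 1)


def margin_for_mass(mass, n_positions):
    """Uniform margin m giving the target exactly `mass` (0 < mass < 1, n_positions >= 2):
    e^m / (e^m + n - 1) = mass  =>  m = log(mass * (n - 1) / (1 - mass))."""
    return math.log(mass * (n_positions - 1) / (1 - mass))


if __name__ == "__main__":
    for n in (16, 256, 4096):
        # Fixed margin: target weight shrinks with n. Fixed weight 0.9: margin grows like log n.
        print(n, round(target_attention_mass(5.0, n), 3), round(margin_for_mass(0.9, n), 2))
```

Under this toy model, a fixed logit gap learned on short contexts dilutes as the context grows, which is why length generalization hinges on how concentrated the learned attention is.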
