
Multi-head Transformers Provably Learn Symbolic Multi-step Reasoning via Gradient Descent

Tong Yang, Yu Huang, Yingbin Liang, Yuejie Chi

TL;DR

This work theoretically shows that a one-layer, multi-head transformer can learn symbolic, multi-step reasoning for path-finding on trees using chain-of-thought. It provides explicit constructions and gradient-descent analyses for backward (goal-to-root) and forward (root-to-goal) tasks, including a two-head architecture and stage-switching to coordinate two subtasks within a single autoregressive pass. The results establish convergence and generalization guarantees to unseen trees, offering a mechanistic explanation for how CoT enables shallow transformers to emulate sequential algorithms. These findings illuminate how intermediate reasoning steps can empower shallow models, with implications for understanding why larger models exhibit emergent reasoning as task complexity and trace length increase.
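To make the two tasks concrete, the sketch below computes their target outputs directly. It assumes a tree is given as a list of (parent, child) edges with distinct integer labels; this input format and the helper names `goal_to_root` and `root_to_goal` are hypothetical choices for illustration, not the paper's encoding. The backward task follows parent pointers from the goal up to the root. The forward task reuses that result and reverses it, mirroring the dependency described above.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]  # hypothetical encoding: (parent, child)


def parent_map(edges: List[Edge]) -> Dict[int, int]:
    """Map each child node to its unique parent."""
    parents: Dict[int, int] = {}
    for parent, child in edges:
        if child in parents:
            raise ValueError(f"node {child} has two parents; not a tree")
        parents[child] = parent
    return parents


def goal_to_root(edges: List[Edge], goal: int) -> List[int]:
    """Backward task: follow parent pointers from the goal up to the root."""
    parents = parent_map(edges)
    path = [goal]
    while path[-1] in parents:  # the root is the only node without a parent
        path.append(parents[path[-1]])
    return path


def root_to_goal(edges: List[Edge], goal: int) -> List[int]:
    """Forward task: solve the backward task first, then reverse its output."""
    return list(reversed(goal_to_root(edges, goal)))


if __name__ == "__main__":
    # Perfect binary tree of depth 2: root 1, children 2 and 3, leaves 4..7.
    tree = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
    print(goal_to_root(tree, 5))
    print(root_to_goal(tree, 5))
```

The forward function deliberately calls the backward one rather than searching downward from the root, since finding the correct branch from the root requires knowing the goal's ancestors.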

Abstract

Transformers have demonstrated remarkable capabilities in multi-step reasoning tasks. However, understanding of the underlying mechanisms by which they acquire these abilities through training remains limited, particularly from a theoretical standpoint. This work investigates how transformers learn to solve symbolic multi-step reasoning problems through chain-of-thought processes, focusing on path-finding in trees. We analyze two intertwined tasks: a backward reasoning task, where the model outputs a path from a goal node to the root, and a more complex forward reasoning task, where the model implements two-stage reasoning by first identifying the goal-to-root path and then reversing it to produce the root-to-goal path. Our theoretical analysis, grounded in the dynamics of gradient descent, shows that trained one-layer transformers can provably solve both tasks with generalization guarantees to unseen trees. In particular, our multi-phase training dynamics for forward reasoning elucidate how different attention heads learn to specialize and coordinate autonomously to solve the two subtasks in a single autoregressive pass. These results provide a mechanistic explanation of how trained transformers can implement sequential algorithmic procedures. Moreover, they offer insights into the emergence of reasoning abilities, suggesting that when tasks are structured to include intermediate chain-of-thought steps, even shallow multi-head transformers can effectively solve problems that would otherwise require deeper architectures.
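The two-stage forward procedure can be written as one autoregressive loop in which every new token is computed only from the context: the edge list, the goal, and the tokens generated so far. This is a hypothetical sketch of the sequential algorithm the trained model is argued to emulate, not of the transformer itself. The separator token `SEP`, the trace layout (backward path, separator, forward path), and the names `next_token` and `generate` are assumptions made here for illustration.

```python
from typing import List, Optional, Tuple, Union

Edge = Tuple[int, int]  # hypothetical encoding: (parent, child)
Token = Union[int, str]
SEP = "|"  # hypothetical stage-switch marker between the two subtasks


def next_token(edges: List[Edge], goal: int, trace: List[Token]) -> Optional[Token]:
    """Predict the next token from the context alone; None means stop."""
    parents = {child: parent for parent, child in edges}
    if not trace:
        return goal  # the backward stage starts at the goal
    if SEP not in trace:
        # Stage 1 (backward): emit the parent of the last node, or switch stage at the root.
        last = trace[-1]
        return parents[last] if last in parents else SEP
    sep = trace.index(SEP)
    backward, forward = trace[:sep], trace[sep + 1:]
    if not forward:
        return backward[-1]  # stage 2 starts at the root
    if forward[-1] == goal:
        return None  # the forward path has reached the goal
    # Stage 2 (forward): the next node is the one that preceded the
    # current node in the stage-1 output, i.e. read the trace in reverse.
    i = backward.index(forward[-1])
    return backward[i - 1]


def generate(edges: List[Edge], goal: int) -> List[Token]:
    """Run the single autoregressive pass until next_token signals a stop."""
    trace: List[Token] = []
    while (tok := next_token(edges, goal, trace)) is not None:
        trace.append(tok)
    return trace


if __name__ == "__main__":
    tree = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
    print(generate(tree, 5))
```

Stage 1 reads parent pointers from the input edges, while stage 2 reads back through the model's own earlier output. This split is loosely analogous to the head specialization described in the abstract, although in the paper the lookups are realized by attention rather than explicit dictionary and index operations.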

Paper Structure

This paper contains 120 sections, 31 theorems, 512 equations, 8 figures.

Key Result

Theorem 1

Under Assumption asmp:construct, for any $\alpha\in\mathbb{R}$, there exists $B= B_\alpha \in\mathbb{R}^{d_1\times d_1}$ such that [...]. Let $\theta=\{B_\alpha\}$; then for any tree ${\mathcal{T}}$, we have $\widehat{O}_\mathsf{g2r}({\mathcal{T}};\theta)\rightarrow O_\mathsf{g2r}({\mathcal{T}})$ as $\alpha\rightarrow +\infty$.
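The role of the limit $\alpha\rightarrow+\infty$ can be illustrated with a toy attention step: when an attention score is scaled by $\alpha$, softmax concentrates its weight on the highest-scoring token, so the output approaches a hard lookup. The sketch below is a simplification and not the paper's matrix $B_\alpha$. It assumes each context token is an edge (parent, child), that the score is $\alpha$ when the edge's child equals the query node and 0 otherwise, and that the output is the attention-weighted average of one-hot parent vectors. The function name `attend_to_parent` and this score form are hypothetical.

```python
import math
from typing import List, Tuple


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def attend_to_parent(edges: List[Tuple[int, int]], query: int,
                     alpha: float, num_nodes: int) -> List[float]:
    """One soft attention step over (parent, child) edge tokens.

    Score of edge (p, c) is alpha if c == query, else 0. The output is the
    attention-weighted sum of one-hot vectors for the parents p.
    """
    scores = [alpha if child == query else 0.0 for _, child in edges]
    weights = softmax(scores)
    out = [0.0] * num_nodes
    for w, (parent, _) in zip(weights, edges):
        out[parent] += w
    return out


if __name__ == "__main__":
    # Perfect binary tree of depth 2 on nodes 0..6; the parent of node 4 is node 1.
    tree = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]
    for alpha in (1.0, 5.0, 20.0):
        print(alpha, round(attend_to_parent(tree, 4, alpha, 7)[1], 6))
```

For every finite $\alpha$ some weight leaks onto non-matching edges, which is consistent with the theorem stating convergence as $\alpha\rightarrow+\infty$ rather than exact equality at a fixed $\alpha$.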

Figures (8)

  • Figure 1: An illustration of the path-finding reasoning task in a tree. Solving the forward task (finding the root-to-goal path) requires solving the backward task (finding the goal-to-root path) first.
  • Figure 2: The multi-step reasoning process of the constructed transformers for the backward and forward reasoning tasks. Color indicates the attention association and output in each step.
  • Figure 3: An example of a perfect binary tree of depth $m=2$ with distinct nodes.
  • Figure 4: Training and test loss curves for backward reasoning.
  • Figure 5: Training dynamics of selected entries of $H$.
  • ...and 3 more figures
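Figure 3 refers to a perfect binary tree of depth $m=2$ with distinct nodes. A minimal generator for such trees is sketched below. It assumes depth counts edges from the root to the leaves, so depth $m$ gives $2^{m+1}-1$ nodes, and it enforces distinctness by assigning a random permutation of integer labels. The function name `perfect_binary_tree` and its (edges, root, leaves) return format are hypothetical.

```python
import random
from typing import List, Tuple


def perfect_binary_tree(depth: int, seed: int = 0) -> Tuple[List[Tuple[int, int]], int, List[int]]:
    """Return (edges, root, leaves) for a perfect binary tree with distinct labels.

    Nodes are laid out in heap order (position i has children 2i+1 and 2i+2),
    then relabeled by a random permutation of 0..n-1 so every label is distinct.
    """
    n = 2 ** (depth + 1) - 1
    rng = random.Random(seed)
    labels = list(range(n))
    rng.shuffle(labels)
    edges: List[Tuple[int, int]] = []
    for i in range((n - 1) // 2):  # internal positions only
        for c in (2 * i + 1, 2 * i + 2):
            edges.append((labels[i], labels[c]))  # (parent, child)
    leaves = labels[(n - 1) // 2:]
    return edges, labels[0], leaves


if __name__ == "__main__":
    edges, root, leaves = perfect_binary_tree(2, seed=1)
    print("root:", root, "leaves:", leaves)
    print("edges:", edges)
```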

Theorems & Definitions (32)

  • Theorem 1: Construction for backward reasoning
  • Theorem 2: Construction for forward reasoning
  • Theorem 3: Convergence of backward reasoning
  • Theorem 4: Generalization of backward reasoning
  • Theorem 5: Training dynamics of forward reasoning
  • Theorem 6: Generalization of forward reasoning
  • Lemma 1
  • Proof
  • Lemma 2: Loss simplification
  • Lemma 3: Gradient computation
  • ...and 22 more