Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought

Jianhao Huang, Zixuan Wang, Jason D. Lee

TL;DR

This work investigates how Chain-of-Thought prompting shapes the training dynamics of transformers in an in-context weight prediction task for linear regression. It shows that CoT enables a single-layer linear transformer to autonomously implement multi-step gradient descent, achieving near-exact recovery of the true weight vector and demonstrating generalization to unseen data. The authors develop a gradient-flow analysis, proving global convergence under mild conditions and revealing a two-stage training dynamic, with improvements for looped transformers. Empirically, CoT prompts yield substantial gains over non-CoT baselines and corroborate the theoretical insights on both in-distribution and out-of-distribution data. The results illuminate CoT as a mechanism that guides optimization-like computations in transformers, suggesting broader implications for multi-step reasoning in neural models.

Abstract

Chain of Thought (CoT) prompting has been shown to significantly improve the performance of large language models (LLMs), particularly in arithmetic and reasoning tasks, by instructing the model to produce intermediate reasoning steps. Despite the remarkable empirical success of CoT and its theoretical advantages in enhancing expressivity, the mechanisms underlying CoT training remain largely unexplored. In this paper, we study the training dynamics of transformers over a CoT objective on an in-context weight prediction task for linear regression. We prove that while a one-layer linear transformer without CoT can only implement a single step of gradient descent (GD) and fails to recover the ground-truth weight vector, a transformer with CoT prompting can learn to perform multi-step GD autoregressively, achieving near-exact recovery. Furthermore, we show that the trained transformer generalizes effectively to unseen data. With our technique, we also show that looped transformers significantly improve final performance compared to transformers without looping in the in-context learning of linear regression. Empirically, we demonstrate that CoT prompting yields substantial performance improvements.
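To make the contrast between single-step and multi-step GD concrete, the following is a minimal NumPy sketch of the reference algorithm itself, not of the transformer or the authors' code. It assumes noiseless data $y = Xw^*$ with standard Gaussian inputs and a zero initial weight; the function name `multi_step_gd` and the data-dependent step size are illustrative choices, while $n=20$, $d=10$, $k=20$ mirror the values quoted in the Figure 1 caption.

```python
import numpy as np

def multi_step_gd(X, y, eta, k):
    """Return the iterates w_0, ..., w_k of GD on the least-squares loss.

    Each iterate plays the role of one intermediate "reasoning step":
    w_{t+1} = w_t - (eta / n) * X^T (X w_t - y), starting from w_0 = 0.
    """
    n, d = X.shape
    w = np.zeros(d)
    iterates = [w.copy()]
    for _ in range(k):
        grad = X.T @ (X @ w - y) / n
        w = w - eta * grad
        iterates.append(w.copy())
    return iterates

rng = np.random.default_rng(0)
n, d, k = 20, 10, 20
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
y = X @ w_star  # noiseless labels (assumption)

# Assumption: a conservative step size 1 / lambda_max of the empirical
# covariance, chosen so that every step shrinks the error.
eta = 1.0 / np.linalg.eigvalsh(X.T @ X / n).max()

iterates = multi_step_gd(X, y, eta, k)
errors = [float(np.sum((w - w_star) ** 2)) for w in iterates]
print(f"1-step error: {errors[1]:.4f}, {k}-step error: {errors[k]:.4f}")
```

Under these assumptions the squared error after $k$ steps is never larger than after one step, which is the gap that the CoT-trained transformer is claimed to close by emitting the intermediate iterates autoregressively.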

Paper Structure

This paper contains 58 sections, 41 theorems, 296 equations, 3 figures.

Key Result

Theorem 3.1

If the global minimizer of $\mathcal{L}^{\mathrm{Eval}}(\bm{V},\bm{W})$ is $(\bm{V}^*,\bm{W}^*)$, the corresponding one-layer transformer $f_{\mathrm{LSA}}(\bm{Z}_0)_{[:,-1]}$ implements one step of GD on a linear model with learning rate $\eta^* = \frac{n}{n+d+1}$ and the transformer outputs $(\b
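The value $\eta^* = \frac{n}{n+d+1}$ can be sanity-checked with a short calculation. The sketch below is not taken from the paper's proof; it assumes isotropic Gaussian inputs $\bm{x}_i \sim \mathcal{N}(\bm{0}, \bm{I}_d)$, noiseless labels $\bm{y} = \bm{X}\bm{w}^*$, a zero initial weight, and a one-step update of the form $\frac{\eta}{n}\bm{X}^\top\bm{y}$. The symbol $\bm{S}$ is introduced here for illustration.

```latex
\begin{align*}
\bm{S} &:= \tfrac{1}{n}\bm{X}^\top\bm{X}, \qquad
\bm{w}_1 = \tfrac{\eta}{n}\bm{X}^\top\bm{y} = \eta\,\bm{S}\bm{w}^*, \\
\mathbb{E}\,\|\bm{w}_1 - \bm{w}^*\|^2
  &= \bm{w}^{*\top}\bigl(\eta^2\,\mathbb{E}[\bm{S}^2] - 2\eta\,\mathbb{E}[\bm{S}] + \bm{I}_d\bigr)\bm{w}^* \\
  &= \Bigl(\eta^2\,\tfrac{n+d+1}{n} - 2\eta + 1\Bigr)\|\bm{w}^*\|^2,
\end{align*}
% using the Wishart moments E[S] = I_d and E[S^2] = ((n+d+1)/n) I_d.
\begin{align*}
\frac{\partial}{\partial\eta}\Bigl(\eta^2\,\tfrac{n+d+1}{n} - 2\eta + 1\Bigr) = 0
  \;\Longrightarrow\; \eta^* = \frac{n}{n+d+1}, \qquad
\min_\eta\,\mathbb{E}\,\|\bm{w}_1 - \bm{w}^*\|^2 = \frac{d+1}{n+d+1}\,\|\bm{w}^*\|^2 .
\end{align*}
```

Under these assumptions the best single step still leaves an error proportional to $\frac{d+1}{n+d+1}$, which is consistent with the theorem's role as a lower bound without CoT.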

Figures (3)

  • Figure 1: Model weights: We present the heatmap of the weights of the trained transformer. We initialize $\bm{V},\bm{W}$ randomly at $t=0$, where $n=20$, $d=10$ and $k=20$. After training, all entries of $\bm{V}$ and $\bm{W}$ converge to zero except the two blocks highlighted in the red box. Moreover, the pattern matches the theoretical results.
  • Figure 2: $k$-step vs. 1-step: We plot the evaluation loss $\mathcal{L}^{\mathrm{Eval}}$ for $n=20$, $d=10$, with the transformer initialized randomly. With CoT, the loss converges to near zero; without CoT, it does not. Moreover, the loss at convergence decreases as $k$ increases.
  • Figure 3: OOD Generalization: We plot the OOD loss $\mathcal{L}^{\mathrm{Eval}}_{\bm{\Sigma}}$ for $n=20$, $d=10$. Each set of experiments samples 10 different $\bm{\Sigma}$. The mean is plotted as a line, with the variance shown as a shaded area. The OOD loss converges to near zero.

Theorems & Definitions (75)

  • Theorem 3.1: Lower bound without CoT
  • Corollary 3.1
  • proof
  • Theorem 3.2: Informal
  • Theorem 4.1: Informal, Global Convergence
  • Lemma 4.1: Informal, gradient components of the reduced model
  • Theorem 4.2: Informal, evaluation
  • Theorem 4.3: Informal, global convergence for looped transformers
  • Lemma A.1: Corollary A.4 in gatmiry2024can
  • Lemma A.2: Equation (24) in gatmiry2024can
  • ...and 65 more