Transformers Provably Solve Parity Efficiently with Chain of Thought

Juno Kim, Taiji Suzuki

TL;DR

The paper analyzes how transformers can be trained to solve parity by recursively generating intermediate states, i.e., chain-of-thought reasoning. It proves that parity is hard for any finite-sample gradient-based method without intermediate supervision, and shows that teacher forcing enables parity to be learned in a single gradient update. It further demonstrates that end-to-end CoT with data augmentation and self-consistency checks can achieve efficient learning in a logarithmic number of steps, even without teacher forcing. Numerical experiments support the theory, revealing a phased, hierarchical learning process and the value of process supervision for complex multi-step tasks. Overall, the work provides a theoretical blueprint for enabling task decomposition and robust multi-step reasoning in transformers through CoT training.

Abstract

This work provides the first theoretical analysis of training transformers to solve complex problems by recursively generating intermediate states, analogous to fine-tuning for chain-of-thought (CoT) reasoning. We consider training a one-layer transformer to solve the fundamental $k$-parity problem, extending the work on RNNs by Wies et al. (2023). We establish three key results: (1) any finite-precision gradient-based algorithm, without intermediate supervision, requires substantial iterations to solve parity with finite samples. (2) In contrast, when intermediate parities are incorporated into the loss function, our model can learn parity in one gradient update when aided by \emph{teacher forcing}, where ground-truth labels of the reasoning chain are provided at each generation step. (3) Even without teacher forcing, where the model must generate CoT chains end-to-end, parity can be learned efficiently if augmented data is employed to internally verify the soundness of intermediate steps. Our findings, supported by numerical experiments, show that task decomposition and stepwise reasoning naturally arise from optimizing transformers with CoT; moreover, self-consistency checking can improve multi-step reasoning ability, aligning with empirical studies of CoT.
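As a concrete illustration of the task decomposition described above, the following is a minimal sketch of the $k$-parity target and its intermediate parities, not the paper's transformer or training procedure. It assumes the $\pm 1$ encoding (so parity is a product of bits), assumes $k$ is a power of two, and pairs the relevant bits in sorted order. The function names and the pairing order are illustrative choices, not taken from the paper.

```python
import numpy as np

def sample_inputs(n, d, rng):
    """Draw n inputs uniformly from {-1, +1}^d."""
    return rng.choice([-1, 1], size=(n, d))

def parity(x, subset):
    """k-parity in the +/-1 encoding: product of the bits indexed by `subset`."""
    return np.prod(x[:, subset], axis=1)

def intermediate_parities(x, subset):
    """Pairwise products, level by level (assumes len(subset) is a power of two).

    Returns one array per level; the last level holds the full parity.
    """
    level = x[:, subset]
    chain = []
    while level.shape[1] > 1:
        level = level[:, 0::2] * level[:, 1::2]  # multiply adjacent pairs
        chain.append(level)
    return chain

rng = np.random.default_rng(0)
d, k = 16, 8
subset = np.sort(rng.choice(d, size=k, replace=False))  # hidden relevant bits
x = sample_inputs(32, d, rng)

chain = intermediate_parities(x, subset)
# The final intermediate state equals the k-parity of the input.
assert np.array_equal(chain[-1][:, 0], parity(x, subset))

n_steps = sum(c.shape[1] for c in chain)  # k - 1 intermediate states in total
n_levels = len(chain)                     # log2(k) levels of the hierarchy
```

Under these assumptions the chain contains $k-1$ intermediate states arranged in $\log_2 k$ levels. These are the kind of intermediate parities that enter the loss function in results (2) and (3).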

Paper Structure

This paper contains 31 sections, 10 theorems, 100 equations, 5 figures.

Key Result

Theorem 1

Let $\ell_{0-1}$ be the zero-one loss. There exists an $O(e^{-d/3})$-approximate oracle $\widetilde{\nabla}$ such that, with probability at least $1-O(e^{-d/3})$, when the target parity $p$ is uniformly sampled from $P$, ...

Footnote: The original paper states that $\mathcal{A}$ can be any iterative gradient-based algorithm which receives an $\Omega(e^{-d/3})$-approximation of the gradient at each step. However, ...

Figures (5)

  • Figure 1: A hierarchical decomposition of an $8$-parity problem for $d=16$. Here $x_{17}=x_1x_4$ so that $\mathsf{c}_1[17]=1$, $\mathsf{c}_2[17]=4$, $\mathsf{p}[17]=21$ and $\mathsf{h}[17]=1$.
  • Figure 2: Illustration of the recursive data generation process by the transformer model. (a) Each token consists of a one-hot positional encoding $\bm{e}_j$ and parity data $\bm{x}_j$. The $d$ input tokens (blue) are fixed. The token $\hat{\bm{x}}_m$ is generated at the $(m-d)$th step by computing attention scores based on position, combining the previous tokens and applying the feedforward layer $\phi$. $\hat{\bm{x}}_{d+k-1}$ is returned as the model prediction. (b) For the no teacher forcing setup in Section \ref{sec:without}, data augmentation $\bm{u}_j$ is implemented to check for self-consistency. If the augmented outputs from the previous generation (red) are uninformative, a filter $\iota$ is applied to zero out the subsequent output.
  • Figure 3: Causal mask for $\mathbf{W}^\top$ with teacher forcing (left); without teacher forcing (right). The gray entries are set to $-\infty$.
  • Figure 4: CoT loss (left) and prediction loss (right) curves for the four models when $d=64$, $k=32$. For the CoT+consistency model, dashed lines indicate when the filters of each level are deactivated.
  • Figure 5: CoT loss (left) and prediction loss (right) curves for the four models when $d=64$, $k=32$ (top), $k=16$ (middle) and $k=8$ (bottom). For the CoT+consistency model, dashed lines indicate when the filters of each level are deactivated.
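
The index bookkeeping in the Figure 1 caption can be sketched as follows. This is a hedged reconstruction from the caption alone. It assumes that $\mathsf{c}_1$, $\mathsf{c}_2$ are the two children of a node, $\mathsf{p}$ its parent and $\mathsf{h}$ its height, and that intermediate nodes are numbered $d+1, \dots, d+k-1$ level by level. The helper `build_decomposition` and all leaf indices other than 1 and 4 are hypothetical.

```python
def build_decomposition(d, leaves):
    """Build child/parent/height maps for a balanced binary tree over `leaves`.

    Nodes are 1-indexed: inputs are 1..d, intermediate nodes d+1..d+k-1.
    Assumes k = len(leaves) is a power of two (k >= 2).
    """
    k = len(leaves)
    assert k >= 2 and k & (k - 1) == 0, "k must be a power of two"
    c1, c2, p = {}, {}, {}
    h = {j: 0 for j in range(1, d + 1)}  # input tokens have height 0
    level, next_index, height = list(leaves), d + 1, 0
    while len(level) > 1:
        height += 1
        new_level = []
        for a, b in zip(level[0::2], level[1::2]):
            m = next_index
            next_index += 1
            c1[m], c2[m] = a, b   # x_m = x_a * x_b
            p[a] = p[b] = m       # m is the parent of both children
            h[m] = height
            new_level.append(m)
        level = new_level
    return c1, c2, p, h

# Only the pair (1, 4) comes from the caption; the other leaves are made up.
leaves = [1, 4, 6, 7, 9, 12, 13, 16]
c1, c2, p, h = build_decomposition(16, leaves)
print(c1[17], c2[17], p[17], h[17])
```

With this numbering, node 17 has children 1 and 4, parent 21 and height 1, in agreement with the caption, and the root is node $d+k-1=23$, the token returned as the model prediction in Figure 2.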

Theorems & Definitions (13)

  • Theorem 1: Wies23, Theorem 4
  • Theorem 2: hardness of finite-sample parity
  • Proposition 3: Shai17, Theorem 1
  • Lemma 4
  • Theorem 5: CoT with teacher forcing
  • Lemma 6
  • Theorem 7: CoT without teacher forcing
  • Lemma 8
  • proof
  • Lemma 9: concentration of interaction terms
  • ...and 3 more