From Sparse Dependence to Sparse Attention: Unveiling How Chain-of-Thought Enhances Transformer Sample Efficiency

Kaiyue Wen, Huaqing Zhang, Hongzhou Lin, Jingzhao Zhang

TL;DR

This work investigates why chain-of-thought prompts improve reasoning in transformers, arguing that sample efficiency—not just expressiveness—is the bottleneck. Through a parity-function framework, it shows exponential sample complexity without CoT but polynomial, near-linear complexity with CoT, driven by sparse sequential dependencies that yield sparse, interpretable attention. Theoretical results are complemented by empirical parity experiments and real-world GSM8K data, which confirm the central role of sparsity in attention for CoT-enabled learning. The findings suggest that CoT data shapes the optimization landscape, enabling efficient generalization and interpretable representations in attention mechanisms.

Abstract

Chain-of-thought (CoT) significantly enhances the reasoning performance of large language models (LLMs). While current theoretical studies often attribute this improvement to increased expressiveness and computational capacity, we argue that expressiveness is not the primary limitation in the LLM regime, as current large models still fail on simple tasks. Using a parity-learning setup, we demonstrate that CoT can substantially improve sample efficiency even when the representation power is sufficient. Specifically, with CoT a transformer can learn the function with a polynomial number of samples, whereas without CoT the required sample size is exponential. Additionally, we show that CoT simplifies the learning process by introducing sparse sequential dependencies among input tokens, and that it leads to sparse, interpretable attention. We validate our theoretical analysis with both synthetic and real-world experiments, confirming that sparsity in attention layers is a key factor in the improvement induced by CoT.
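The two parity settings in the abstract can be made concrete with a small data generator. This is a minimal sketch under stated assumptions, not the paper's data pipeline: it assumes 0/1 input bits, a label equal to the XOR of the bits at the secret indices $S$, and a CoT target that lists the running XOR over $S$ in order (consistent with the Figure 3 description of the $i$-th CoT token attending to $S[i]$). The helper names `parity_example` and `dependencies` are hypothetical.

```python
import random

def parity_example(n, secret, use_cot, rng):
    """One (input bits, target tokens) pair for the (n, k) parity problem.

    Assumed format: x is a list of n random 0/1 bits; without CoT the target is
    the single parity of x over the secret indices; with CoT the target is the
    running parity over the secret indices, taken in the given order.
    """
    x = [rng.randint(0, 1) for _ in range(n)]
    running, acc = [], 0
    for idx in secret:
        acc ^= x[idx]          # step i combines the previous step with x[S[i]]
        running.append(acc)
    return (x, running) if use_cot else (x, running[-1:])

def dependencies(secret, use_cot):
    """Tokens each target token depends on: ('x', i) is input i, ('c', j) is CoT step j."""
    if not use_cot:
        return [{("x", i) for i in secret}]   # one token depending on all k bits
    deps = [{("x", secret[0])}]
    for j in range(1, len(secret)):
        deps.append({("c", j - 1), ("x", secret[j])})
    return deps

rng = random.Random(0)
n, k = 20, 6
secret = sorted(rng.sample(range(n), k))
x, y = parity_example(n, secret, use_cot=True, rng=rng)
assert y[-1] == sum(x[i] for i in secret) % 2   # last CoT token is the answer
print(max(len(d) for d in dependencies(secret, use_cot=False)))  # 6
print(max(len(d) for d in dependencies(secret, use_cot=True)))   # 2
```

Without CoT the single target token depends jointly on all $k$ secret bits, while under this assumed format each CoT token depends on at most two earlier tokens; this is the kind of sparse sequential dependence the abstract describes.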

Paper Structure (30 sections, 35 theorems, 146 equations, 11 figures, 1 table)

Key Result

Theorem 1

Consider the Transformer model defined in def:trans. For any $\delta<0.1$ and large enough $n$, when the number of secret indices $k$ is in $[n / \log^5 (n / \delta), n / \log^4(n/\delta)]$, then with probability at least $1 - \delta$ over the randomness of the embedding $e$, there exists a weight configuration …
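To get a feel for the regime in Theorem 1, the admissible window for $k$ can be evaluated numerically. This sketch assumes $\log$ is the natural logarithm (the base is not stated here) and does not check the "large enough $n$" condition; `theorem1_k_range` is a hypothetical helper.

```python
import math

def theorem1_k_range(n, delta):
    """Return (lower, upper) = (n / log^5(n/delta), n / log^4(n/delta)).

    Assumption: natural logarithm. Theorem 1 additionally requires delta < 0.1
    and a sufficiently large n, which this sketch does not check.
    """
    log_term = math.log(n / delta)
    return n / log_term**5, n / log_term**4

lo, hi = theorem1_k_range(10**6, 0.05)
print(f"k in [{lo:.2f}, {hi:.2f}]")  # k in [0.74, 12.52]
```

The two endpoints differ by exactly a factor of $\log(n/\delta)$, so even for $n = 10^6$ the window contains only small values of $k$ under this reading.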

Figures (11)

  • Figure 1: (a) We show that, without Chain-of-Thought (CoT), the sample complexity for training transformers to learn the parity function grows exponentially with the hardness parameter $k$. In contrast, utilizing CoT significantly improves sample efficiency. (b) We also show that the sparsity of attention layers, measured by normalized attention entropy, is crucial in the parity learning experiment. In both CoT and non-CoT scenarios, as the attention layers become sparser, indicated by a rapid decrease in normalized entropy, evaluation accuracy jumps correspondingly.
  • Figure 2: The sample complexity for learning parity without CoT increases exponentially with $k$. CoT significantly reduces the sample complexity, demonstrating exponential improvement across varying numbers of heads and layers.
  • Figure 3: (a) When $n$ is fixed, the sample complexity of learning parity with CoT grows approximately linearly with $k$. (b) The attention pattern learned by the transformer with CoT is interpretable, as the $i$-th output token of CoT predominantly attends to secret index $S[i]$.
  • Figure 4: Evaluation accuracy of transformers on the $(n=20, k=6)$ parity problem without CoT. The model is trained on a dataset of $10,000$ samples for $1,000$ epochs. Almost all layer-head configurations achieve perfect evaluation accuracy. Adding more heads is more effective than adding layers. The blue dashed line marks the with-CoT setup, which achieves perfect accuracy in 5 epochs.
  • Figure 5: $4$-layer $4$-head transformer trained on the $(n=20, k=6)$ parity problem without CoT using multi-pass training, detailed in the paper's multi-pass training section. When trained on a small dataset of $50{,}000$ samples, the model achieves perfect evaluation accuracy (top), accompanied by a significant decrease in attention entropy. Surprisingly, when trained on a much larger training set of $1{,}000{,}000$ samples, the model fails to learn (bottom), and both the training loss and the normalized attention entropy remain high. This suggests that the development of attention sparsity may improve optimization.
  • ...and 6 more figures
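The normalized attention entropy used in Figures 1 and 5 can be sketched as follows. The exact normalization and averaging in the paper may differ; this sketch assumes each attention row is a probability distribution over the positions a query can see, divides its Shannon entropy by the log of the row length (so the value lies in $[0, 1]$), and averages over rows. The idealized pattern, in which the $i$-th CoT token puts all of its weight on input position $S[i]$ as in Figure 3(b), and all helper names are hypothetical.

```python
import math

def normalized_entropy(row):
    """Shannon entropy of one attention row divided by log(len(row)).

    Gives 0 for one-hot (maximally sparse) attention and 1 for uniform
    attention. Requires len(row) > 1 so that the normalizer is nonzero.
    """
    h = -sum(p * math.log(p) for p in row if p > 0)
    return h / math.log(len(row))

def mean_normalized_entropy(attn):
    """Average normalized entropy over rows with more than one visible position."""
    vals = [normalized_entropy(row) for row in attn if len(row) > 1]
    return sum(vals) / len(vals)

def one_hot(length, pos):
    """Attention row of the given length with all weight on index pos."""
    return [1.0 if j == pos else 0.0 for j in range(length)]

# Hypothetical example: n input bits followed by CoT tokens; the CoT token at
# position n + i sees n + i + 1 positions (causal attention including itself).
n = 8
secret = [1, 4, 6]
sparse = [one_hot(n + i + 1, s) for i, s in enumerate(secret)]
uniform = [[1.0 / (n + i + 1)] * (n + i + 1) for i in range(len(secret))]

print(round(mean_normalized_entropy(sparse), 6))   # 0.0
print(round(mean_normalized_entropy(uniform), 6))  # 1.0
```

A one-hot pattern scores 0 and uniform attention scores 1, so the drop in normalized entropy reported in Figures 1 and 5 corresponds to attention rows moving toward the one-hot end.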

Theorems & Definitions (75)

  • Definition 1: Parity Problem $(n,k)$ without CoT
  • Definition 2: Parity Problem $(n,k)$ with CoT
  • Definition 3: Embedding Module
  • Definition 4: Attention Module
  • Definition 5
  • Definition 6: Simplified Transformer Block
  • Theorem 1: Easy to Represent
  • Theorem 2: Hard to Learn
  • Theorem 3: Easy to Learn with CoT
  • Proof of Theorem 1 (Easy to Represent)
  • ...and 65 more