
Transformers with RL or SFT Provably Learn Sparse Boolean Functions, But Differently

Bochen Lyu, Yiyang Jia, Xiaohao Cai, Zhanxing Zhu

TL;DR

This work analyzes how RL and SFT enable chain-of-thought reasoning in a one-layer transformer to learn $k$-sparse Boolean functions via recursive 2-sparse decompositions. It derives sufficient conditions under which provable learning is achieved for both RL (with immediate rewards) and SFT (without teacher forcing), and validates these conditions on $k$-PARITY, $k$-AND, and $k$-OR. A key finding is that RL can learn the entire CoT chain in a single gradient update, while SFT learns the chain step-by-step, reflecting intrinsic differences in supervision signals. The results provide mechanistic insights into how CoT emerges under RL versus SFT and offer guidance for designing reasoning-based fine-tuning regimes in transformers.
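The difference in supervision signal between RL with immediate rewards and RL with only a final reward (the setting of Proposition 3.1) can be sketched in a few lines of Python. The 0/1 reward values and the helper names `immediate_rewards` and `final_reward` are hypothetical choices made for illustration; the paper's exact reward definition may differ.

```python
from typing import List, Sequence


def immediate_rewards(pred: Sequence[int], target: Sequence[int]) -> List[float]:
    """Per-step reward: 1.0 where a CoT token matches the ground truth, else 0.0.

    Assumption: this 0/1 reward is illustrative, not the paper's definition.
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return [1.0 if p == t else 0.0 for p, t in zip(pred, target)]


def final_reward(pred: Sequence[int], target: Sequence[int]) -> float:
    """Single reward that checks only the last token (the final answer)."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and the same length")
    return 1.0 if pred[-1] == target[-1] else 0.0


# A chain with a wrong middle step but a correct final answer.
target_chain = [1, 0, 1]
pred_chain = [1, 1, 1]
print(immediate_rewards(pred_chain, target_chain))  # [1.0, 0.0, 1.0]
print(final_reward(pred_chain, target_chain))       # 1.0
```

With per-step rewards, the incorrect middle token is identified directly; the final reward alone gives this chain full credit and carries no information about which step went wrong.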

Abstract

Transformers can acquire Chain-of-Thought (CoT) capabilities to solve complex reasoning tasks through fine-tuning. Reinforcement learning (RL) and supervised fine-tuning (SFT) are two primary approaches to this end, yet their underlying mechanisms and differences remain theoretically unclear. In this work, we examine these aspects specifically for learning $k$-sparse Boolean functions with a one-layer transformer and intermediate supervision that is akin to CoT. In particular, we consider $k$-sparse Boolean functions that can be recursively decomposed into fixed 2-sparse Boolean functions. We analyze the learning dynamics of fine-tuning the transformer via either RL or SFT with CoT to identify sufficient conditions for it to provably learn these functions. We verify that these conditions hold for three basic examples, namely $k$-PARITY, $k$-AND, and $k$-OR, thus demonstrating the learnability of both approaches. Notably, we reveal that RL and SFT exhibit distinct learning behaviors: RL learns the whole CoT chain simultaneously, whereas SFT learns the CoT chain step-by-step. Overall, our findings provide theoretical insights into the underlying mechanisms of RL and SFT as well as how they differ in triggering the CoT capabilities of transformers.
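The recursive decomposition into fixed 2-sparse functions can be illustrated with a short Python sketch. It assumes inputs in {0, 1}, uses XOR, AND, and OR as the 2-sparse building blocks for $k$-PARITY, $k$-AND, and $k$-OR, and pairs adjacent values level by level, carrying an unpaired value forward. The function name `cot_chain`, the `BUILDING_BLOCKS` table, and this pairing order are illustrative assumptions, not necessarily the decomposition drawn in Figure 1; each intermediate level plays the role of intermediate CoT steps.

```python
import operator
from typing import Callable, List, Sequence

# Fixed 2-sparse building blocks on {0, 1} inputs (assumed encoding).
BUILDING_BLOCKS = {
    "PARITY": operator.xor,
    "AND": operator.and_,
    "OR": operator.or_,
}


def cot_chain(
    x: Sequence[int], B: Sequence[int], g: Callable[[int, int], int]
) -> List[List[int]]:
    """Return every level of a pairwise reduction of x restricted to B.

    Level 0 holds the k relevant bits x[i] for i in B. Each later level
    applies the 2-sparse function g to adjacent pairs of the previous level,
    carrying an unpaired last value forward unchanged. The final level holds
    the single value Phi_k(x).
    """
    if len(B) < 2:
        raise ValueError("expected k >= 2 relevant coordinates")
    level = [x[i] for i in B]
    levels = [level]
    while len(level) > 1:
        nxt = [g(level[j], level[j + 1]) for j in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            nxt.append(level[-1])  # carry the unpaired value to the next level
        level = nxt
        levels.append(level)
    return levels
```

For example, with x = [1, 0, 1, 1, 0, 1, 0, 0] and B = [0, 1, 2, 3] (0-indexed, k = 4), `cot_chain(x, B, BUILDING_BLOCKS["PARITY"])` returns [[1, 0, 1, 1], [1, 0], [1]], whose last entry is the parity of the four relevant bits. Swapping in `BUILDING_BLOCKS["AND"]` or `BUILDING_BLOCKS["OR"]` changes only the 2-sparse function, not the structure of the chain.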

Paper Structure

This paper contains 61 sections, 11 theorems, 133 equations, 4 figures, 1 table.

Key Result

Theorem 3.1

Given integers $d \geq k \geq 2$, consider a $k$-sparse Boolean function $\Phi_k(\cdot)$ with any subset $B \subseteq [d]$ as in Definition 2.1. Let ${\bm{W}}(0) = \mathbf{1}$ be the initialization, and consider the optimal parameter that solves $\max_{{\bm{W}}}\mathcal{R}({\bm{W}})$. Set the learning rate $\eta = \Omega\left( \ln(d/\epsilon)\right)$ for any $\epsilon > 0$. If the separation of the

Figures (4)

  • Figure 1: Recursive decomposition of solving $\Phi_k({\mathbf{x}})$
  • Figure 3: The pretrained transformer iteratively uses its output to solve $\Phi_k({\mathbf{x}})$ in a CoT manner.
  • Figure 4: (a) The ground truth $(\sigma^{\star})_{N_{t - 1} + l^{(t)}}^{N_{t - 2} + p}$. Each white box is $0.5$ and each gray box is $0$. (b) $\mathop{\mathrm{sign}}\nolimits(\nabla_{{\bm{W}}}{\mathcal{L}}({\bm{W}}))$ at ${\bm{W}}(0) = \mathbf{1}$. Each white box has value $+1$ and each black box has value $-1$. Gray boxes have value 0 due to the causal mask and the pretrained mask.
  • Figure 5: $\mathop{\mathrm{sign}}\nolimits(- \nabla_{{\bm{W}}}{\mathcal{L}}({\bm{W}}))$ computed at ${\bm{W}}(s)$ for different update steps $s$. Each white box has value $+1$, each black box has value $-1$, and gray boxes are 0.

Theorems & Definitions (22)

  • Definition 2.1: $k$-sparse Boolean functions
  • Theorem 3.1: Learnability of $k$-sparse Boolean functions via RL
  • Proposition 3.1: Hardness of RL with final reward
  • Theorem 3.2: Learnability of $k$-sparse Boolean functions via SFT
  • Theorem 4.1: $k$-PARITY learning dynamics of fine-tuning via RL
  • Claim 4.1: Transformers with CoT can learn $k$-PARITY via SFT
  • Claim 4.2: Transformers with CoT can learn $k$-AND and $k$-OR via RL or SFT
  • Lemma 1: Formulation of the policy gradient
  • Proof
  • Lemma 2
  • ...and 12 more