Metastable Dynamics of Chain-of-Thought Reasoning: Provable Benefits of Search, RL and Distillation

Juno Kim, Denny Wu, Jason Lee, Taiji Suzuki

TL;DR

This work presents a metastable Markov framework for chain-of-thought (CoT) reasoning, in which reasoning steps cluster into densely connected easy steps and sparse hard steps that govern the long-time dynamics. It analyzes how inference-time search, reinforcement learning, and distillation accelerate reaching the correct reasoning clusters by identifying and amplifying sparse inter-cluster edges and by compressing cluster transitions into a meta-chain. The authors prove hitting-time bounds and convergence guarantees for pretraining, sparse-reward search (PRM), PPO-Clip fine-tuning, and distillation, along with learning-theoretic limits (SDA) showing that global information is necessary for certain logical tasks. These results justify practical inference-time strategies for improving CoT in large language models and inform how to distill reasoning patterns into smaller, more efficient models. The work offers a principled perspective on the trade-offs between search, RL, and distillation, and suggests directions for future study of inference-time computation and its scaling laws.

Abstract

A key paradigm to improve the reasoning capabilities of large language models (LLMs) is to allocate more inference-time compute to search against a verifier or reward model. This process can then be utilized to refine the pretrained model or distill its reasoning patterns into more efficient models. In this paper, we study inference-time compute by viewing chain-of-thought (CoT) generation as a metastable Markov process: easy reasoning steps (e.g., algebraic manipulations) form densely connected clusters, while hard reasoning steps (e.g., applying a relevant theorem) create sparse, low-probability edges between clusters, leading to phase transitions at longer timescales. Under this framework, we prove that implementing a search protocol that rewards sparse edges improves CoT by decreasing the expected number of steps to reach different clusters. In contrast, we establish a limit on reasoning capability when the model is restricted to local information of the pretrained graph. We also show that the information gained by search can be utilized to obtain a better reasoning model: (1) the pretrained model can be directly finetuned to favor sparse edges via policy gradient methods, and moreover (2) a compressed metastable representation of the reasoning dynamics can be distilled into a smaller, more efficient model.
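To make the role of sparse edges concrete, the following is a minimal sketch of a toy two-cluster chain. It is an illustration only and not the paper's construction: the cluster size `M`, the single "gateway" state, the helper names `build_chain` and `expected_hitting_times`, and the two values of `eps` are all assumptions chosen for the example. Cluster A mixes uniformly among its own states, one gateway state jumps to cluster B with probability `eps`, and the code solves the standard linear system for the expected number of steps to reach cluster B.

```python
def build_chain(M, eps):
    """Toy chain (hypothetical): states 0..M-1 are cluster A, M..2M-1 are cluster B."""
    n = 2 * M
    P = [[0.0] * n for _ in range(n)]
    gateway = M - 1  # the only state in A with an edge into B
    for x in range(M):
        stay = 1.0 - eps if x == gateway else 1.0
        for y in range(M):
            P[x][y] = stay / M  # dense, uniform moves inside cluster A
    P[gateway][M] = eps  # sparse, low-probability inter-cluster edge
    for x in range(M, n):
        for y in range(M, n):
            P[x][y] = 1.0 / M  # cluster B also mixes uniformly
    return P


def expected_hitting_times(P, target):
    """Solve h(x) = 1 + sum_{y not in target} P[x][y] * h(y) for every x outside target."""
    free = [x for x in range(len(P)) if x not in target]
    k = len(free)
    # Augmented matrix for (I - Q) h = 1, where Q is P restricted to non-target states.
    A = [
        [(1.0 if i == j else 0.0) - P[free[i]][free[j]] for j in range(k)] + [1.0]
        for i in range(k)
    ]
    # Gauss-Jordan elimination with partial pivoting.
    for col in range(k):
        pivot = max(range(col, k), key=lambda r: abs(A[r][col]))
        A[col], A[pivot] = A[pivot], A[col]
        for r in range(k):
            if r != col:
                f = A[r][col] / A[col][col]
                for c in range(col, k + 1):
                    A[r][c] -= f * A[col][c]
    return {free[i]: A[i][k] / A[i][i] for i in range(k)}


M = 5
cluster_b = set(range(M, 2 * M))

# Baseline chain: the hard step is rarely taken.
h_base = expected_hitting_times(build_chain(M, eps=0.01), cluster_b)
# Same chain with the sparse edge amplified, standing in for a search- or RL-improved model.
h_boost = expected_hitting_times(build_chain(M, eps=0.1), cluster_b)

print("expected steps from state 0, baseline:", h_base[0])
print("expected steps from state 0, boosted: ", h_boost[0])
```

In this toy chain the expected hitting time from a non-gateway state works out to $1 + M/\varepsilon$, so raising the probability of the sparse edge by a factor of ten shortens the expected path by roughly the same factor. The paper's results concern far more general graphs; the sketch only shows why the sparse edges dominate the long-time behavior.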

Paper Structure

This paper contains 47 sections, 42 theorems, 177 equations, 3 figures, 4 algorithms.

Key Result

Proposition 2.1

Any subset $S_\circ=\{x_1,\cdots,x_K\}\subset S$ of cluster representatives $x_k\in C_k$ constitutes a metastable system for $X^\varepsilon$ in the sense of eq:bovier as $M\to\infty$.

Figures (3)

  • Figure 1: (Left) Example of metastable graph with three clusters. Each state represents a logical assertion and edges correspond to reasoning steps. Solid and dashed arrows indicate easy (within-cluster) and hard (inter-cluster) reasoning steps, respectively. The goal of the reasoner is to retrieve a valid CoT path from $X_{\mathrm{in}}$ to $X_{\mathrm{out}}$ (highlighted). Search aims to use CoT generated from the pretrained model to explore the linguistic model and identify hard steps, which can then be used to fine-tune the pretrained model via RL to improve its generation. (Right) The coarse-grained dynamics of CoT at long timescales can be represented by a meta-chain on the set of clusters and distilled into a smaller model, which can generate reasoning paths more efficiently.
  • Figure 2: Sparse edge construction for the no-search scenario. The two circles represent the original dense clusters. Dashed edges have probability $\varepsilon$.
  • Figure 3: Graph construction for the local search scenario ($r=2$). Dashed edges have probability $\varepsilon$. A local neighborhood of maximum distance one from the original graph is shown in the dashed box, which the learner is assumed to have full access to.

Theorems & Definitions (77)

  • Proposition 2.1
  • Theorem 3.1: convergence of pretraining
  • Theorem 3.2: expected hitting time
  • Proposition 3.3
  • Proposition 3.4: convergence of PPO-Clip
  • Proposition 4.1
  • Proposition 4.2: convergence of distillation
  • Theorem 4.3: hitting time of distilled CoT
  • Definition 5.1: SDA: SQDIM with access
  • Theorem 5.2: SQ learning with additional information
  • ...and 67 more