Metastable Dynamics of Chain-of-Thought Reasoning: Provable Benefits of Search, RL and Distillation
Juno Kim, Denny Wu, Jason Lee, Taiji Suzuki
TL;DR
This work presents a metastable Markov framework for chain-of-thought (CoT) reasoning in which reasoning steps separate into dense clusters of easy steps connected by sparse hard steps, and the hard steps govern the long-time dynamics. It analyzes how inference-time search, reinforcement learning, and distillation shorten the time needed to reach correct reasoning clusters, by identifying and amplifying sparse inter-cluster edges and by compressing cluster transitions into a meta-chain. The authors prove hitting-time bounds and convergence guarantees for pretraining, sparse-reward search (PRM), PPO-Clip fine-tuning, and distillation, along with learning-theoretic limits (SDA) showing that global information is necessary for certain logical tasks. These results justify practical inference-time strategies for improving CoT in large language models and inform how reasoning patterns can be distilled into smaller, more efficient representations. The work offers a principled perspective on the trade-offs between search, RL, and distillation, and suggests directions for future study of inference-time computation and its scaling laws.
Abstract
A key paradigm to improve the reasoning capabilities of large language models (LLMs) is to allocate more inference-time compute to search against a verifier or reward model. This process can then be utilized to refine the pretrained model or distill its reasoning patterns into more efficient models. In this paper, we study inference-time compute by viewing chain-of-thought (CoT) generation as a metastable Markov process: easy reasoning steps (e.g., algebraic manipulations) form densely connected clusters, while hard reasoning steps (e.g., applying a relevant theorem) create sparse, low-probability edges between clusters, leading to phase transitions at longer timescales. Under this framework, we prove that implementing a search protocol that rewards sparse edges improves CoT by decreasing the expected number of steps to reach different clusters. In contrast, we establish a limit on reasoning capability when the model is restricted to local information of the pretrained graph. We also show that the information gained by search can be utilized to obtain a better reasoning model: (1) the pretrained model can be directly finetuned to favor sparse edges via policy gradient methods, and moreover (2) a compressed metastable representation of the reasoning dynamics can be distilled into a smaller, more efficient model.
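To make the hitting-time claim concrete, the following is a minimal sketch on a toy chain, not the construction used in the paper. It assumes two clusters of `k` states each, uniform (dense) transitions inside a cluster, and a single sparse edge of probability `eps` leaving cluster A from one gateway state. The names `build_chain`, `expected_hitting_times`, `k`, and `eps` are hypothetical and chosen only for illustration. The expected number of steps to reach cluster B is obtained by solving the standard linear system (I - Q) h = 1 over the non-target states.

```python
def build_chain(k, eps):
    """Toy metastable chain (illustrative assumption, not the paper's model).

    States 0..k-1 form cluster A and states k..2k-1 form cluster B.
    State 0 is the only gateway: it jumps to state k with probability eps.
    All remaining probability mass is spread uniformly inside the cluster.
    """
    n = 2 * k
    P = [[0.0] * n for _ in range(n)]
    for i in range(k):
        stay = 1.0 - eps if i == 0 else 1.0
        for j in range(k):
            P[i][j] = stay / k
    P[0][k] = eps  # the sparse "hard step" between clusters
    for i in range(k, n):
        for j in range(k, n):
            P[i][j] = 1.0 / k
    return P


def expected_hitting_times(P, targets):
    """Expected steps to reach `targets` from every other state.

    Solves (I - Q) h = 1 by Gauss-Jordan elimination, where Q is the
    transition matrix restricted to the non-target states.
    """
    states = [s for s in range(len(P)) if s not in targets]
    m = len(states)
    # Augmented matrix [I - Q | 1].
    A = [
        [(1.0 if r == c else 0.0) - P[states[r]][states[c]] for c in range(m)]
        + [1.0]
        for r in range(m)
    ]
    for col in range(m):
        # Partial pivoting for numerical stability.
        pivot = max(range(col, m), key=lambda r: abs(A[r][col]))
        A[col], A[pivot] = A[pivot], A[col]
        for r in range(m):
            if r != col:
                f = A[r][col] / A[col][col]
                for c in range(col, m + 1):
                    A[r][c] -= f * A[col][c]
    return {states[r]: A[r][m] / A[r][r] for r in range(m)}


if __name__ == "__main__":
    k = 5
    cluster_b = set(range(k, 2 * k))
    for eps in (0.01, 0.2):  # pretrained sparse edge vs. amplified edge
        h = expected_hitting_times(build_chain(k, eps), cluster_b)
        print(f"eps={eps}: expected steps from state 1 = {h[1]:.2f}")
```

In this toy chain the hitting time from a non-gateway state has the closed form 1 + k/eps, so raising the probability of the sparse edge (a stand-in for what search or fine-tuning would do under the assumptions above) shortens the expected time to leave the cluster in inverse proportion to `eps`.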
