Emergence of Superposition: Unveiling the Training Dynamics of Chain of Continuous Thought
Hanlin Zhu, Shibo Hao, Zhiting Hu, Jiantao Jiao, Stuart Russell, Yuandong Tian
TL;DR
This work analyzes how continuous chain-of-thought (continuous CoT) can emerge from gradient-based training in a simplified two-layer transformer for directed graph reachability. Analyzing Coconut-style training and splitting it into a thought-generation stage and a final-prediction stage, the authors show that the index-matching logit $\mu(t)$ first grows and then remains bounded under mild conditions. This fosters a balanced exploration-exploitation dynamic and leads to a superposition of multiple reasoning traces. They derive the exact forward pass, loss, and gradient dynamics, and establish that a positive, finite $\mu$ both enables one-step frontier expansion and lets the model maintain multiple plausible traces in parallel when it is uncertain. In the prediction stage, two signals, the residual carryover $\mu_A$ and the candidate lift $\mu_R$, jointly raise the logit of the correct destination node. A gradient-flow analysis reveals an implicit maximum-margin bias, which supports generalization to unseen graphs. Experiments on a ProsQA-derived dataset confirm the theoretical predictions, showing bounded logits, rapid adaptation across reasoning depths, and high final accuracy. Together, these results offer a mechanistic explanation for the observed benefits of continuous CoT and may inform scalable training of latent reasoning in large language models by clarifying how latent reasoning traces are learned and used during gradient-based training.
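The interplay of the two prediction-stage signals can be illustrated with a toy numerical sketch. It assumes, as our own simplification rather than the paper's exact parameterization, that the output logit of each node is an additive mix of a residual-carryover term ($\mu_A$ times the node's weight in the continuous thought) and a candidate-lift term ($\mu_R$ for each of the two answer candidates). The function `predict`, the 6-node instance, and the numeric values of $\mu_A$ and $\mu_R$ are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def predict(thought, candidates, mu_A, mu_R):
    """Toy prediction-stage distribution over nodes.

    ASSUMPTION: logits are an additive mix of a residual-carryover term
    (mu_A times the node's weight in the continuous thought) and a
    candidate-lift term (mu_R for each of the two answer candidates).
    """
    lift = np.zeros_like(thought)
    lift[list(candidates)] = 1.0
    logits = mu_A * thought + mu_R * lift
    return softmax(logits)

# Hypothetical 6-node instance: nodes 0, 1, 2, 4 are reachable from the root,
# so the thought is a uniform superposition over them.
n = 6
reachable = [0, 1, 2, 4]
thought = np.zeros(n)
thought[reachable] = 1.0 / len(reachable)

candidates = (4, 5)  # node 4 is reachable (correct answer), node 5 is not

both = predict(thought, candidates, mu_A=8.0, mu_R=3.0)
carry_only = predict(thought, candidates, mu_A=8.0, mu_R=0.0)
lift_only = predict(thought, candidates, mu_A=0.0, mu_R=3.0)

print(int(both.argmax()))                        # 4
print(np.isclose(carry_only[4], carry_only[0]))  # True: ties with a reachable non-candidate
print(np.isclose(lift_only[4], lift_only[5]))    # True: ties with the wrong candidate
```

In this toy, neither signal alone isolates the answer: carryover alone cannot distinguish node 4 from the other reachable nodes, and lift alone cannot distinguish the two candidates. Only their combination singles out the candidate that is also present in the superposition.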
Abstract
Previous work has shown that the chain of continuous thought (continuous CoT) improves the reasoning capability of large language models (LLMs) by enabling implicit parallel thinking, and subsequent work provided theoretical insight by showing that a two-layer transformer equipped with continuous CoT can efficiently solve directed graph reachability by maintaining a superposition of multiple reasoning traces in the continuous thought. However, it remains unclear how the superposition mechanism is naturally learned by gradient-based training methods. To fill this gap, we theoretically analyze the training dynamics of a simplified two-layer transformer on the directed graph reachability problem to unveil how the superposition mechanism emerges in two training stages: (i) a thought-generation stage that autoregressively expands the continuous thought, and (ii) a prediction stage that converts the thought into the final answer. Our analysis reveals that during training with continuous thought, the index-matching logit, an important quantity that reflects the strength of the model's local search ability, first increases and then remains bounded under mild assumptions. The bounded index-matching logit effectively balances exploration and exploitation during the reasoning process: the model exploits local problem structures to identify plausible search traces and, when it is uncertain which solution is correct, assigns comparable weights to multiple such traces to explore, which results in superposition. Our experimental results tracking the growth of logits further validate our theory.
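Why a bounded index-matching logit balances exploration and exploitation can be seen in a toy one-step frontier expansion. This sketch rests on our own simplifying assumption, not the paper's exact model: the attention logit on an edge token $(s, t)$ equals $\mu$ times the weight of the source node $s$ in the current thought, and the attention mass is routed to each edge's target node. The function `expand_frontier`, the graph, and the thought weights are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def expand_frontier(thought, edges, mu):
    """One toy expansion step over edge tokens.

    ASSUMPTION: the attention logit on edge (s, t) is mu * thought[s], i.e. it
    scales with how strongly the source node s is present in the current thought.
    Returns the attention mass routed to each edge's target node.
    """
    src = np.array([s for s, _ in edges])
    tgt = np.array([t for _, t in edges])
    attn = softmax(mu * thought[src])    # attention over all edge tokens
    new_mass = np.zeros_like(thought)
    np.add.at(new_mass, tgt, attn)       # accumulate mass on target nodes
    return new_mass

# Hypothetical graph: frontier nodes 1 and 2 lead to 3 and 4;
# edges 5 -> 6 and 6 -> 7 are unrelated to the current search.
edges = [(1, 3), (2, 4), (5, 6), (6, 7)]
thought = np.zeros(8)
thought[[1, 2]] = [0.55, 0.45]   # frontier present with slightly unequal weights

for mu in (1.0, 5.0, 50.0):
    out = expand_frontier(thought, edges, mu)
    print(mu, np.round(out[[3, 4]], 3), round(float(out[6] + out[7]), 3))
```

In this toy, a small $\mu = 1$ leaks roughly 38% of the mass to the unrelated nodes 6 and 7 (too little exploitation). A moderate $\mu = 5$ keeps substantial weight on both frontier traces while leakage drops below 10%. A very large $\mu = 50$ sends over 99% of the mass down the single trace through node 3, erasing the superposition.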
