Continuous Chain of Thought Enables Parallel Exploration and Reasoning

Halil Alperen Gozeten, M. Emrullah Ildiz, Xuechen Zhang, Hrayr Harutyunyan, Ankit Singh Rawat, Samet Oymak

TL;DR

Continuous Chain-of-Thought (CoT2) replaces single-token autoregressive sampling with continuous, simplex-weighted token mixtures to enable parallel exploration of multiple reasoning traces. The paper introduces CSFT, a supervised training framework that injects soft, budgeted targets for intermediate steps and preserves a discrete final decision, and develops RL strategies (GRPO with Multi-Token and Dirichlet sampling) to optimize continuous traces. Theoretical contributions include an embedding-capacity bound and a one-layer transformer construction that solves the MNNS task, plus formal comparisons between CoT2, CoT, and CoT2-MTS. Empirically, CoT2 with CSFT and RL outperforms discrete CoT and COCONUT on MNNS, ProntoQA, and ProsQA, with clear tradeoffs between embedding size and parallelism and improved inference efficiency through parallel reasoning.

Abstract

Modern language models generate chain-of-thought traces by autoregressively sampling tokens from a finite vocabulary. While this discrete sampling has achieved remarkable success, conducting chain-of-thought with continuously-valued tokens (CoT2) offers a richer and more expressive alternative. Our work provides new theoretical guarantees and algorithms for CoT2, motivated by logical reasoning tasks that inherently require search capabilities. Theoretically, we establish how CoT2 enables the model to track multiple discrete traces in parallel, and we quantify the level of achievable parallelism and its benefits for inference efficiency. We also provide a CoT2-based one-layer transformer construction that solves the combinatorial "subset sum problem" given a sufficient embedding dimension. These insights arise from a novel and effective supervision strategy where we match the language model outputs to the empirical token distributions of a set of target traces. Complementing this, we introduce sampling strategies that unlock policy optimization methods for CoT2. Our primary strategy samples and composes $K$ discrete tokens at each decoding step to control the level of parallelism. Experiments confirm that (i) the optimal level of parallelism is governed by the embedding dimension, (ii) our continuous supervision strategy can outperform alternative methods, and (iii) policy optimization with CoT2 indeed improves the performance of the model beyond its initial discrete or continuous supervision.
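The sampling strategy in the abstract (sample $K$ discrete tokens per decoding step, then compose them) can be illustrated with a minimal sketch. It assumes the $K$ tokens are drawn i.i.d. from the model's next-token distribution and composed by averaging their embeddings; the function name `multi_token_sample`, the toy distribution, and the one-hot embedding table are hypothetical and not taken from the paper.

```python
import random

def multi_token_sample(probs, embeddings, k, rng):
    """Draw k token ids i.i.d. from probs and average their embeddings.

    Assumption: composition is a plain average of the k sampled embeddings.
    """
    ids = rng.choices(range(len(probs)), weights=probs, k=k)
    dim = len(embeddings[0])
    mixed = [sum(embeddings[i][j] for i in ids) / k for j in range(dim)]
    return ids, mixed

# Toy setup (hypothetical): 3-token vocabulary with one-hot embeddings,
# so the mixed vector is the empirical distribution of the sampled tokens.
rng = random.Random(0)
probs = [0.5, 0.3, 0.2]
embeddings = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
ids, mixed = multi_token_sample(probs, embeddings, k=4, rng=rng)
```

Under this reading, `k=1` reduces to ordinary discrete sampling, while larger `k` lets the continuous token carry several candidate tokens at once.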

Paper Structure

This paper contains 28 sections, 10 theorems, 85 equations, 10 figures, 5 tables, 1 algorithm.

Key Result

Proposition 1

There exists a $1$-layer transformer architecture that solves the MNNS task using CoT2 by storing (sine, cosine) embeddings of all $2^k$ states at the $k$-th iteration in a non-overlapping manner.
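One way to see why (sine, cosine) embeddings suit this task: by the angle-addition identities, adding a number to a state is a fixed rotation of its embedding, and a rotation is linear, so it acts on every state in a superposition at once. The sketch below illustrates only this property and is not the paper's construction; the frequency `omega` and the names `embed` and `shift` are hypothetical.

```python
import math

def embed(s, omega):
    """(sine, cosine) embedding of an integer state s."""
    return (math.sin(omega * s), math.cos(omega * s))

def shift(e, x, omega):
    """Rotate an embedding so that it represents state + x (linear in e)."""
    sin_s, cos_s = e
    c, d = math.cos(omega * x), math.sin(omega * x)
    return (sin_s * c + cos_s * d, cos_s * c - sin_s * d)

omega = 0.1  # hypothetical frequency
# A superposition of states -2 and 2 (first MNNS step for input 2).
mix = tuple((a + b) / 2 for a, b in zip(embed(-2, omega), embed(2, omega)))
# Shifting the mixture by +1 equals mixing the shifted states -1 and 3.
shifted_mix = shift(mix, 1, omega)
expected = tuple((a + b) / 2 for a, b in zip(embed(-1, omega), embed(3, omega)))
```

Subtraction is the same rotation with `-x`, so both branches of each step can be applied to the whole mixture.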

Figures (10)

  • Figure 1: Illustration of CoT2 with varying budgets $B$ for the Minimum Non-Negative Sum (MNNS) task with $m=3$ and input numbers $2,1,4$. CoT2 supervision with budget $B$ at steps $t\in\{1,\dots,m-1\}$ is the average of the embeddings of the states visited by $B$ selected trajectories among the 8 possible, and for $t=m$ it is the embedding corresponding to the answer. For $B=1$ (discrete CoT), the correct trajectory $(-2,\,-3,\,1)$, highlighted in yellow, is used; for $B=2$, the red and yellow trajectories are used; for $B=8$, all trajectories are included in supervision.
  • Figure 2: (a): Discrete CoT model requires multiple samplings (Pass@k) to match the single‐shot performance of CoT2 model on MNNS (10‐run avgs). (b): CoT2 model outperforms COCONUT, discrete CoT, and no-CoT in tasks involving search, like MNNS and ProsQA (5-run avgs). (c): Tradeoff between the number of trajectories superposed and the embedding dimension (5-run avgs). Setting: MNNS with 4 input digits in $1$–$9$. In (a-b), $B$ is the full budget for CoT2, and $B=1$ for discrete CoT. (a): 1‐layer, 1‐head GPT2 with $d=24$. (b): MNNS: 2-layer 2-head GPT2, $d=32$; ProsQA: 4-layer 4-head GPT2, $d=32$. (c): 2‐layer, 2‐head GPT2 with $d \in \{16, 24,32\}$.
  • Figure 3: Training performance vs. embedding dimension for CoT2 ($B=16$) and discrete CoT ($B=1$) on MNNS with 4 input digits from 1-9 and 2-layer, 2-head GPT2 with $d \in \{16, 24,32\}$.
  • Figure 4: When the digit range makes an MNNS instance non-trivial, the discrete CoT model trained with full token supervision outperforms models trained with sparse supervision; single-token supervision yields the worst performance. Setting: 5 input digits in $5$–$13$; 2‐layer, 2‐head GPT2 with $d=32$.
  • Figure 5: CoT2 outperforms discrete CoT on ProsQA and also converges faster. Setting: 4‐layer, 4‐head GPT2 with $d \in \{24,32, 40\}$.
  • ...and 5 more figures
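
The budgeted supervision described in the Figure 1 caption can be reproduced on the same instance (inputs $2,1,4$). This is a minimal sketch: the function names are hypothetical, and the rule for choosing which $B$ trajectories to keep (smallest non-negative final sums first) is an assumption, since the caption does not specify it. Targets are shown as distributions over states rather than as averaged embeddings.

```python
from collections import Counter
from itertools import product

def trajectories(nums):
    """All 2^m sequences of partial sums obtained by adding or subtracting each number."""
    out = []
    for signs in product((1, -1), repeat=len(nums)):
        total, states = 0, []
        for sign, x in zip(signs, nums):
            total += sign * x
            states.append(total)
        out.append(tuple(states))
    return out

def mnns(nums):
    """Minimum non-negative final sum over all sign assignments."""
    return min(t[-1] for t in trajectories(nums) if t[-1] >= 0)

def soft_targets(nums, budget):
    """Per-step target distributions over states for a given budget B."""
    # Assumed selection rule: non-negative final sums first, smallest first.
    ranked = sorted(trajectories(nums), key=lambda t: (t[-1] < 0, abs(t[-1])))
    chosen = ranked[:budget]
    targets = []
    for step in range(len(nums) - 1):
        counts = Counter(t[step] for t in chosen)
        targets.append({s: c / len(chosen) for s, c in counts.items()})
    targets.append({mnns(nums): 1.0})  # the final step stays discrete
    return targets

nums = [2, 1, 4]
discrete = soft_targets(nums, 1)  # [{-2: 1.0}, {-3: 1.0}, {1: 1.0}]
full = soft_targets(nums, 8)      # step 1 is uniform over {2, -2}
```

With `budget=1` the targets collapse to the single correct trajectory $(-2,-3,1)$, matching discrete CoT; with `budget=8` each intermediate step spreads its mass uniformly over all reachable states.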

Theorems & Definitions (18)

  • Proposition 1: Solving MNNS
  • Proposition 2: Consistency of CoT and CoT2 inference
  • Proposition 3
  • Proposition 3: Consistency of CoT and CoT2 inference
  • proof
  • Proposition 3
  • proof
  • Remark 1
  • Proposition 4
  • proof
  • ...and 8 more