
Soft Tokens, Hard Truths

Natasha Butt, Ariel Kwiatkowski, Ismail Labiad, Julia Kempe, Yann Ollivier

TL;DR

The paper tackles the limitations of discrete chain-of-thought (CoT) reasoning by introducing a scalable reinforcement-learning (RL) approach to training continuous CoTs, using soft and fuzzy tokens with noise added to the input embeddings. It shows that continuous CoTs can match discrete CoTs at pass@1 and exceed them at pass@32 on math reasoning benchmarks, and that the trained models can be deployed robustly by switching to discrete tokens at inference time. The key contributions are an RL fine-tuning method that requires no ground-truth CoTs, minimal computational overhead, and evidence that continuous CoT training better preserves base-model behavior on out-of-domain tasks. The work suggests that continuous reasoning is a practical, scalable way to fine-tune large language models with richer internal reasoning paths.

Abstract

The use of continuous instead of discrete tokens during the Chain-of-Thought (CoT) phase of reasoning LLMs has garnered attention recently, based on the intuition that a continuous mixture of discrete tokens could simulate a superposition of several reasoning paths simultaneously. Theoretical results have formally proven that continuous tokens have much greater expressivity and can solve specific problems more efficiently. However, practical use of continuous tokens has been limited by strong training difficulties: previous works either only use continuous tokens at inference time on a pre-trained discrete-token model, or must distill the continuous CoT from ground-truth discrete CoTs and face computational costs that limit the CoT to very few tokens. This is the first work introducing a scalable method to learn continuous CoTs via reinforcement learning (RL), without distilling from reference discrete CoTs. We use "soft" tokens: mixtures of tokens together with noise on the input embedding to provide RL exploration. Computational overhead is minimal, enabling us to learn continuous CoTs with hundreds of tokens. On math reasoning benchmarks with Llama and Qwen models up to 8B, training with continuous CoTs matches discrete-token CoTs for pass@1 and surpasses them for pass@32, showing greater CoT diversity. In systematic comparisons, the best-performing scenario is to train with continuous CoT tokens then use discrete tokens for inference, meaning the "soft" models can be deployed in a standard way. Finally, we show continuous CoT RL training better preserves the predictions of the base model on out-of-domain tasks, thus providing a softer touch to the base model.
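The abstract's headline results are stated as pass@1 and pass@32, but the evaluation protocol is not given here. A common choice, assumed in this sketch, is the standard unbiased pass@k estimator introduced with the HumanEval benchmark: draw n ≥ k samples per problem, count the c correct ones, estimate pass@k per problem, and average over the benchmark. The function names are hypothetical, and the paper's exact protocol may differ.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one problem.

    n: samples drawn, c: samples judged correct, k: attempt budget.
    Computes 1 - C(n - c, k) / C(n, k), i.e. one minus the probability that
    a uniformly random size-k subset of the n samples has no correct answer.
    """
    if not (0 <= c <= n and 1 <= k <= n):
        raise ValueError("require 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        return 1.0  # every size-k subset contains at least one correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)


def benchmark_pass_at_k(results: list[tuple[int, int]], k: int) -> float:
    """Mean pass@k over problems, given one (n, c) pair per problem."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)
```

For example, `pass_at_k(32, 4, 1)` gives 0.125 and `pass_at_k(32, 4, 32)` gives 1.0. In general pass@1 reduces to c/n, while with n = k = 32 the estimate is 1 whenever any of the 32 samples is correct, so pass@32 credits a model whose samples explore different solutions; this is consistent with the abstract reading a pass@32 gain as greater CoT diversity.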

Paper Structure

This paper contains 42 sections, 23 equations, 19 figures, and 8 tables.

Figures (19)

  • Figure 1: Hard, fuzzy and soft generation during the CoT phase. In hard generation, at each time step a discrete token $CoT_t$ is sampled from the probability vector $p_{t-1}$, and its embedding $h^0_{CoT_t}$ is passed to the transformer, generating a sequence of discrete CoT tokens $CoT_1, \dots, CoT_T$ over time. In fuzzy and soft generation, at each time step noise $\epsilon_t$ is injected into the probability-weighted mixture embedding $h^0_t = p_{t-1}E$, where $E$ is the token embedding matrix. This noisy input embedding is passed to the transformer, generating a sequence of continuous noisy CoT embeddings $\tilde{h}^0_{CoT_1}, \dots, \tilde{h}^0_{CoT_T}$ over time. Additionally, for fuzzy generation, the temperature $\tau$ used in the CoT phase tends to 0, so that the non-noisy embeddings $h^0$ reduce to embeddings of discrete tokens. We find that the combination of soft/fuzzy training and hard inference performs universally best, matching hard training at pass@$1$ and surpassing it at pass@$32$, indicating better preservation of diversity.
  • Figure 2: Llama 3B Instruct trained on GSM8K. (a) Training performance across steps; one step = two prompts $\times$ 32 samples each. (b) Greedy validation performance used for model selection. For the remaining trained models, see the supplementary training results in the Appendix.
  • Figure 3: Hard-inference pass@k for Llama models (for soft/fuzzy inference and for Qwen models, see the supplementary pass@k results in the Appendix). We observe that soft/fuzzy training improves pass@$32$, pointing to preserved diversity. Greedy pass@1 values (the triangles) for all training methods are clustered together.
  • Figure 4: CoT entropy of Llama 3B Instruct on the GSM8K test set. Fuzzy and soft training preserve the entropy profile of the base model, whereas hard training substantially changes the entropy profile of hard samples.
  • Figure 5: Validation performance for (a) a noise-scale ablation and (b) a temperature ablation, on Llama 3B Instruct with fuzzy training on GSM8K. Fuzzy training appears robust to noise scale factors from 0.1 to 1.0 and to temperature values from 0.1 to 0.0001.
  • ...and 14 more figures
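Figure 1's three input modes can be made concrete with a toy sketch. It assumes isotropic Gaussian noise whose standard deviation is a hypothetical `noise_scale` parameter (the caption says only that noise $\epsilon_t$ is injected), and it treats fuzzy generation as soft generation with a very small temperature, in line with the temperatures down to 0.0001 in Figure 5. The function names and toy sizes are illustrative, not the paper's code, and the transformer itself is omitted: each function only builds the next input embedding from the previous step's logits.

```python
import numpy as np


def softmax(logits: np.ndarray, tau: float) -> np.ndarray:
    """Temperature-scaled softmax over the vocabulary (numerically stable)."""
    z = logits / tau
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def hard_token_embedding(logits, E, tau, rng):
    """Hard generation: sample a discrete token from p_{t-1}, return its row of E."""
    p = softmax(logits, tau)
    token = rng.choice(len(p), p=p)
    return E[token]


def noisy_mixture_embedding(logits, E, tau, noise_scale, rng):
    """Soft/fuzzy generation: mixture embedding p_{t-1} E plus injected noise.

    Soft: moderate tau, so the mixture blends several token embeddings.
    Fuzzy: tau close to 0, so the mixture is (almost) one token's embedding.
    ASSUMPTION: isotropic Gaussian noise with standard deviation noise_scale.
    """
    p = softmax(logits, tau)
    h = p @ E  # (V,) @ (V, d) -> (d,)
    eps = noise_scale * rng.standard_normal(h.shape)
    return h + eps


rng = np.random.default_rng(0)
V, d = 5, 3                      # toy vocabulary size and embedding width
E = rng.standard_normal((V, d))  # hypothetical token embedding matrix
logits = np.array([0.1, 2.0, -1.0, 0.5, 0.0])

hard = hard_token_embedding(logits, E, tau=1.0, rng=rng)
soft = noisy_mixture_embedding(logits, E, tau=1.0, noise_scale=0.1, rng=rng)
fuzzy = noisy_mixture_embedding(logits, E, tau=1e-4, noise_scale=0.1, rng=rng)
```

The sketch isolates one design point: soft and fuzzy inputs differ only in $\tau$. As $\tau \to 0$ the mixture $p_{t-1}E$ collapses onto the most likely token's embedding, so fuzzy inputs are discrete-token embeddings plus noise, whereas hard inputs are exact token embeddings whose randomness comes only from sampling.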
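Figure 4's caption does not say how CoT entropy is measured. One plausible reading, assumed in this sketch, is the Shannon entropy (in nats) of the model's next-token distribution at each CoT position, averaged position-wise over test problems. `token_entropies` and `mean_profile` are hypothetical helpers written for illustration, not functions from the paper.

```python
import numpy as np


def token_entropies(logits: np.ndarray, tau: float = 1.0) -> np.ndarray:
    """Shannon entropy (nats) of the next-token distribution at each step.

    logits: array of shape (T, V) holding one CoT's per-step logits.
    Returns an array of shape (T,).
    """
    z = logits / tau
    z = z - z.max(axis=-1, keepdims=True)                      # stabilise exp
    log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax
    return -(np.exp(log_p) * log_p).sum(axis=-1)


def mean_profile(per_example: list) -> np.ndarray:
    """Average per-step entropies position-wise over CoTs of different lengths."""
    T = max(len(x) for x in per_example)
    sums, counts = np.zeros(T), np.zeros(T)
    for x in per_example:
        sums[: len(x)] += x
        counts[: len(x)] += 1
    return sums / counts
```

Under this reading, plotting `mean_profile` for the base model next to each fine-tuned model gives a comparison of the kind Figure 4 describes; a uniform distribution over V tokens scores log V, and a near-deterministic step scores close to 0.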