Thoughtbubbles: an Unsupervised Method for Parallel Thinking in Latent Space

Houjun Liu, Shikhar Murty, Christopher D. Manning, Róbert Csordás

TL;DR

Thoughtbubbles introduces an unsupervised, adaptive latent computation architecture for transformers that forks residual streams to form computation bubbles, controlled by learned scores and merged at the end to produce outputs. The approach is trained purely with language modeling loss and demonstrates superior perplexity and strong zero-shot performance across OpenWebText and peS2o datasets, often matching or exceeding larger baselines while using computation budgets that adapt to input length. Key innovations include a forking mechanism with cumulative scores, residual update attenuation, and a partial RoPE scheme to handle multiple forks per token. The results suggest that latent adaptive computation can be learned during pretraining, enabling more efficient and interpretable reasoning in language models, with limitations in hardware efficiency and the need for broader downstream testing.

Abstract

Current approaches for scaling inference-time compute in transformers rely on training them to emit explicit chain-of-thought tokens before producing an answer. While these methods are powerful, they cannot be applied during pretraining, and they can scale inference-time compute only through serially generated, natural-language verbalization. In this work, we propose Thoughtbubbles, a transformer variant that natively performs parallel adaptive computation in latent space by learning to fork or delete residual streams. Thus, tokens that require a large amount of computation can form a "bubble" of cloned residuals in the middle of the network for additional thinking. Crucially, this behavior is learned during pretraining with only language modeling loss. Thoughtbubbles outperforms both standard decoder LMs and non-adaptive parallel computation approaches on OpenWebText and peS2o perplexity and in zero-shot evaluations such as HellaSwag and LAMBADA after pretraining across 150M to 772M parameter scales. The implicit nature of our method enables adaptive computation to be learned starting at pretraining time, paving the way to unify train and test-time behavior for reasoning models.
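To make the fork-or-delete idea concrete, here is a minimal sketch in plain Python. It is an illustration under stated assumptions, not the authors' implementation: the function names (`fork_step`, `attenuated_update`), the sigmoid scoring, the top-k selection under a fixed `budget`, and the toy logits are all hypothetical. The only things taken from the text are that residual streams carry a cumulative score, that streams can be cloned or deleted, that forked children sit to the left of their parent, and that residual updates are attenuated by the score.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fork_step(streams, keep_logits, fork_logits, budget):
    """One hypothetical forking step over a list of residual streams.

    streams: list of (token_index, cumulative_score, residual) tuples.
    keep_logits, fork_logits: one logit per stream (stand-ins for learned scores).
    budget: maximum number of streams allowed to survive this step (assumed).
    """
    candidates = []
    for (tok, score, resid), k, f in zip(streams, keep_logits, fork_logits):
        # Assumption: the cumulative score is the running product of scores.
        # The cloned child is listed before (to the left of) its parent.
        candidates.append((tok, score * sigmoid(f), list(resid)))
        candidates.append((tok, score * sigmoid(k), resid))
    # Keep the highest-scoring candidates; everything else is deleted.
    ranked = sorted(range(len(candidates)),
                    key=lambda i: candidates[i][1], reverse=True)
    survivors = sorted(ranked[:budget])  # restore left-to-right order
    return [candidates[i] for i in survivors]


def attenuated_update(resid, score, block):
    """Residual update scaled by the stream's cumulative score."""
    delta = block(resid)
    return [r + score * d for r, d in zip(resid, delta)]


# Toy example: three tokens, one stream each, and room for one extra stream.
streams = [(0, 1.0, [0.0]), (1, 1.0, [0.0]), (2, 1.0, [0.0])]
out = fork_step(streams,
                keep_logits=[2.0, 2.0, 2.0],
                fork_logits=[-3.0, 1.0, -3.0],
                budget=4)
print([tok for tok, _, _ in out])  # token 1 forms a small "bubble": [0, 1, 1, 2]
```

In this toy run only token 1 has a fork score high enough to fit in the budget, so it ends up with two streams while the low-scoring forks of tokens 0 and 2 are dropped. Because the update is multiplied by the cumulative score, a low-scoring stream would also contribute little to its residual before being deleted, which is one plausible reading of why the selection can be trained with language modeling loss alone.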

Paper Structure

This paper contains 46 sections, 10 equations, 7 figures, 4 tables.

Figures (7)

  • Figure 1: Overview of our method: input tokens fork to form a bubble of latent computation (orange), which is then contracted to produce the final token. Some extraneous tokens may fork (dark blue), but then be pruned.
  • Figure 2: Forking procedure. Token "is" has two forks, one of which will get deleted; the token "this" creates a new fork; we show a score-attenuated transformer block after a forking operation.
  • Figure 3: Dev-set perplexity of our approach and various baselines as a function of model scale on both OpenWebText and peS2o datasets. Across all scales, note that our method outperforms all baselines, including both computation and parameter-matched ones. Lower is better.
  • Figure 4: Normalized number of forks in the final layer across a window of $4$ tokens as a function of the mean entropy of those $4$ tokens on OpenWebText. Left: entropy as measured by the forking transformer; right: entropy as measured by a baseline decoder LM.
  • Figure 5: Analysis of attention allocation between the main (rightmost, "og") token and its child forks for our approach trained on OpenWebText. Note that since we place child token embeddings to the left of the main token, forked children cannot attend to their parent.
  • ...and 2 more figures