PonderLM-2: Pretraining LLM with Latent Thoughts in Continuous Space

Boyi Zeng, He Li, Shixiang Song, Yixuan Wang, Ziwei He, Xinbing Wang, Zhouhan Lin

TL;DR

PonderLM-2 introduces a horizontal scaling approach to LLM pretraining. For each token, the model first generates a latent thought in continuous space by feeding its last hidden state back in as an input embedding. It then uses that extra step to predict the next token. Training uses Jacobi iterations to parallelize what would otherwise be a sequential process, which keeps pretraining efficient and scalable. Empirically, models with latent thoughts achieve lower perplexity and stronger performance on general and instruction-following tasks at equal or lower inference cost. They often surpass larger baselines and prior latent-thought methods. The approach works across architectures (Pythia, LLaMA, GPT-2) and can be added to existing foundation-model training pipelines, suggesting possible broader benefits for scalable reasoning and downstream task transfer.
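The Jacobi-iteration idea can be illustrated with a toy model. The sketch below makes several assumptions and is not the authors' implementation:

- `toy_causal_model` is a hypothetical stand-in for a causal Transformer, in which each output row depends only on the current and earlier inputs.
- `interleave`, `sequential_latents` and `jacobi_latents` are illustrative names.
- Gradient flow across rounds is not modeled.

The sketch compares the exact, sequential computation of latent thoughts with the parallel update described in the Figure 3 caption. It also records the MSE between successive rounds, which is analogous to the quantity plotted in Figure 4.

```python
import numpy as np


def toy_causal_model(inputs, W, U):
    """Hypothetical stand-in for a causal LM: output row t depends only on inputs[: t + 1]."""
    # A running mean over the prefix gives each position a causal view of earlier inputs.
    prefix_mean = np.cumsum(inputs, axis=0) / np.arange(1, len(inputs) + 1)[:, None]
    return np.tanh(inputs @ W + prefix_mean @ U)


def interleave(x, h):
    """Build [x_1, h_1, x_2, h_2, ...] from token embeddings x and latent thoughts h."""
    seq = np.empty((2 * len(x), x.shape[1]))
    seq[0::2] = x
    seq[1::2] = h
    return seq


def sequential_latents(x, W, U):
    """Exact latent thoughts: h_i is the last hidden state at x_i given [x_1, h_1, ..., x_i]."""
    latents = np.zeros_like(x)
    for i in range(len(x)):
        # latents[i] is still a zero placeholder, but it sits after x_i, so causality hides it.
        out = toy_causal_model(interleave(x[: i + 1], latents[: i + 1]), W, U)
        latents[i] = out[2 * i]
    return latents


def jacobi_latents(x, W, U, num_rounds):
    """Parallel (Jacobi) approximation: update all latent thoughts at once, num_rounds times."""
    # Step (1): initial hidden states from a plain forward pass over the token embeddings.
    latents = toy_causal_model(x, W, U)
    mse_per_round = []
    for _ in range(num_rounds):
        # Step (2): interleave the previous round's latents with the original embeddings.
        out = toy_causal_model(interleave(x, latents), W, U)
        new_latents = out[0::2]  # hidden states at token positions become the new latent thoughts
        mse_per_round.append(float(np.mean((new_latents - latents) ** 2)))
        latents = new_latents
    # Step (3): outputs at the latent-thought positions would feed the LM head and cross-entropy loss.
    prediction_states = toy_causal_model(interleave(x, latents), W, U)[1::2]
    return latents, prediction_states, mse_per_round
```

For a strictly causal model, round $k$ makes the first $k+1$ latent thoughts exact. So $T-1$ rounds reproduce the sequential result for a length-$T$ sequence. A small fixed $K$ is presumably attractive because the updates change little after a few rounds, which is what Figure 4 measures.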

Abstract

The remarkable success of Chain-of-Thought (CoT), which enhances performance by scaling generation steps at test time, inspires us to ask: can we leverage a similar scaling of computational steps during pretraining to improve the generation of each individual token? To address this, we propose a novel pretraining methodology: Pretraining Language Models with Latent Thoughts (PonderLM-2). Our approach pretrains a language model (LM) to first generate an intermediate latent thought (the last hidden state of the current position), which is then used as input to predict the actual subsequent token. This additional computational step enables the LM to refine its prediction within an unconstrained continuous space. Our experiments demonstrate that, at an identical inference cost, an LM that generates one additional latent thought per token outperforms a standard model with double the parameters. For instance, our PonderLM-2-Pythia-1.4B, pretrained on 300B tokens from the Pile, significantly surpasses the vanilla Pythia-2.8B trained on the same data, both on language modeling and on a range of general downstream tasks. Furthermore, increasing the number of latent thoughts generated before each actual token, forming a chain analogous to CoT, consistently improves the model's performance.
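The inference procedure in the abstract can be made concrete with a minimal NumPy sketch of greedy decoding with latent thoughts. It is not the authors' code, and it rests on three assumptions:

- `toy_causal_model` is a hypothetical stand-in for a causal Transformer.
- `E` is a tied embedding/LM-head matrix.
- The loop re-runs the whole sequence at every step instead of using a KV cache.

The sketch only illustrates the control flow. With `num_thoughts=1`, each real token takes two decoding steps instead of one. Larger values build the chain of latent thoughts that the abstract compares to CoT.

```python
import numpy as np


def toy_causal_model(inputs, W, U):
    """Hypothetical stand-in for a causal LM; output row t depends only on inputs[: t + 1]."""
    prefix_mean = np.cumsum(inputs, axis=0) / np.arange(1, len(inputs) + 1)[:, None]
    return np.tanh(inputs @ W + prefix_mean @ U)


def generate_with_latent_thoughts(prompt_ids, num_new_tokens, num_thoughts, E, W, U):
    """Greedy decoding where each real token is preceded by num_thoughts latent thoughts."""
    inputs = [E[i] for i in prompt_ids]  # input embeddings fed to the model so far
    output_ids = list(prompt_ids)
    forward_passes = 0
    for _ in range(num_new_tokens):
        # Chain of latent thoughts: each last hidden state becomes the next input embedding.
        for _ in range(num_thoughts):
            hidden = toy_causal_model(np.stack(inputs), W, U)
            forward_passes += 1
            inputs.append(hidden[-1])
        # This pass (over the last latent thought, or the last token if num_thoughts=0)
        # produces the hidden state that predicts the real token.
        hidden = toy_causal_model(np.stack(inputs), W, U)
        forward_passes += 1
        logits = hidden[-1] @ E.T  # tied input/output embeddings (an assumption)
        next_id = int(np.argmax(logits))
        output_ids.append(next_id)
        inputs.append(E[next_id])
    return output_ids, forward_passes
```

Setting `num_thoughts=0` recovers ordinary greedy decoding. This makes the cost difference easy to see: generation takes `num_new_tokens * (num_thoughts + 1)` decoding steps in total.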

Paper Structure

This paper contains 18 sections, 4 equations, 9 figures, 6 tables.

Figures (9)

  • Figure 1: Scaling curves comparing our PonderLM-2-Pythia with the official Pythia suite on the 300B Pile. Our 1.26B model matches the loss of Pythia-2.8B with 55% fewer parameters (left), while our 1.4B model reaches the baseline's final performance with 62% less training data (right).
  • Figure 2: A comparison between the standard language model and our PonderLM-2. In the standard language model, each token is generated after a single forward pass. In contrast, PonderLM-2 does not immediately sample the output token after one forward pass; instead, it uses the computed last hidden state as the next input embedding for generating the subsequent output token. This allows the language model to think in an unconstrained latent space before producing each token.
  • Figure 3: Parallel training procedure of our method (via Jacobi iteration). (1) The model computes initial hidden states from the input embeddings ($x_1, x_2, x_3$). These hidden states are then interleaved with their corresponding token embeddings to form a new input sequence. (2) For $K$ rounds, all hidden states are updated in parallel. In each round, the hidden states from the previous round are interleaved with the original embeddings to form the new input. (3) Finally, the cross-entropy loss ($\mathcal{L}_1, \mathcal{L}_2, \mathcal{L}_3$) is computed at the positions corresponding to the hidden-state inputs to optimize language modeling.
  • Figure 4: MSE of the last hidden states before and after the $i$-th iteration. The model is the vanilla Pythia-1B, tested on $4 \times 2048$ tokens.
  • Figure 5: Language Modeling Perplexity (PPL). Our method achieves the lowest perplexity, consistently surpassing PonderLM despite its 2$\times$ inference overhead at the same model size. Our PonderLM-2-Pythia-1.4B also outperforms the larger vanilla Pythia-2.8B. Numbers denote the absolute perplexity improvement ($\downarrow$) over the corresponding Pythia models.
  • ...and 4 more figures
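The headline ratios in the abstract and in Figure 1 can be sanity-checked with simple arithmetic. The sketch below relies on three assumptions, none of which come from the source:

- Forward cost is the common approximation of about $2N$ FLOPs per input position for a model with $N$ parameters.
- Models are counted at their nominal sizes.
- Latent thoughts add no parameters.

It ignores attention cost, which grows because each latent thought lengthens the sequence. It therefore checks only the rough sense in which one latent thought per token matches a model with twice the parameters.

```python
def approx_forward_flops(num_params, positions_per_token):
    """Rough per-token forward FLOPs under the ~2N-per-position rule of thumb (ignores attention)."""
    return 2 * num_params * positions_per_token


# Abstract: one latent thought per token (2 positions) on a 1.4B model vs. vanilla 2.8B (1 position).
ponder_1_4b = approx_forward_flops(1_400_000_000, positions_per_token=2)
vanilla_2_8b = approx_forward_flops(2_800_000_000, positions_per_token=1)

# Figure 1 (left): a 1.26B model matching Pythia-2.8B.
param_reduction_pct = 100 * (1 - 1.26 / 2.8)

# Figure 1 (right): 62% less than the 300B-token training budget.
tokens_needed_b = 300 * (1 - 0.62)

print(ponder_1_4b == vanilla_2_8b)  # True
print(round(param_reduction_pct))   # 55
print(round(tokens_needed_b))       # 114
```

The 114B value is only what the stated 62% implies relative to 300B tokens. The exact crossover point should be read from Figure 1 itself.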