
Generalized Parallel Scaling with Interdependent Generations

Harry Dong, David Brandfonbrener, Eryk Helenowski, Yun He, Mrinal Kumar, Han Fang, Yuejie Chi, Karthik Abinav Sankararaman

TL;DR

Bridge enables interdependent parallel generation for LLMs by treating the batch of hidden states as a holistic 3-D tensor and introducing lightweight cross-sample attention blocks. Trained with a small parameter budget and a brief SFT warm-up, Bridge shares information across parallel generations for the same prompt, boosting both individual accuracy and set-level quality under reinforcement learning with verifiable rewards (RLVR). Across multiple models and math/non-math tasks, Bridge increases the relative mean accuracy gains from RLVR by up to ~39%, maintains performance as the generation width grows, and improves consistency and coverage in output sets. This approach generalizes parallel scaling beyond independent sampling, offering a practical, scalable way to exploit inter-sequence information during generation.
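To make the cross-sample attention concrete, here is a minimal NumPy sketch of one Bridge-style step at a single timestep. It follows the masking described for Figure 2: a token attends only to tokens from the same prompt, and completed generations are masked out. Everything else is an assumption for illustration and not the authors' implementation. This includes the single head, the RMS-style input norm, the always-allowed self-attention diagonal, the residual connection, and all names (`bridge_step`, `Wq`, `Wk`, `Wv`, `Wo`).

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    # Assumed input normalization (RMS-style); the paper's exact layer may differ.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def bridge_step(x, prompt_ids, done, Wq, Wk, Wv, Wo):
    """Hypothetical single-head Bridge-style block for one timestep.

    x: (B, D) hidden states of all B sequences at the current timestep.
    prompt_ids: (B,) id of the prompt each sequence was sampled for.
    done: (B,) True where a sequence has already finished generating.
    """
    h = rms_norm(x)
    q, k, v = h @ Wq, h @ Wk, h @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (B, B) cross-sample scores

    same_prompt = prompt_ids[:, None] == prompt_ids[None, :]
    active_key = ~done[None, :]
    # Assumption: a token may always attend to itself, so no row is fully masked.
    allowed = (same_prompt & active_key) | np.eye(len(x), dtype=bool)

    scores = np.where(allowed, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Residual connection (assumed): cross-sample information is added to x.
    return x + (weights @ v) @ Wo


rng = np.random.default_rng(0)
B, D, d = 6, 16, 8  # 2 prompts x generation width 3, model dim D, head dim d
Wq, Wk, Wv = (rng.standard_normal((D, d)) * 0.1 for _ in range(3))
Wo = rng.standard_normal((d, D)) * 0.1
x = rng.standard_normal((B, D))
prompt_ids = np.array([0, 0, 0, 1, 1, 1])
done = np.array([False, False, True, False, False, False])
out = bridge_step(x, prompt_ids, done, Wq, Wk, Wv, Wo)
print(out.shape)  # (6, 16)
```

The projection weights act only on feature vectors, and the mask is built from `prompt_ids` at run time. As a result, nothing in a block like this depends on how many sequences share a prompt. This is one plausible reason such a block can be trained once and applied at other generation widths.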

Abstract

Parallel LLM inference scaling involves sampling a set of $N>1$ responses for a single input prompt. However, these $N$ parallel responses tend to be generated independently of each other, partitioning compute resources and leaving potentially useful information in one generation untapped by the others. This contrasts with response length scaling, where past computation is used in all future steps. For higher-quality responses and response sets, we propose Bridge, which generates interdependent responses in parallel by rethinking batched LLM hidden states as holistic tensors rather than independent slices. With only a small number of new parameters (2.8% to 5.1%), Bridge improves the relative mean accuracy gains from reinforcement learning with verifiable rewards by up to 39% and boosts the consistency of correct responses. Trained once, Bridge scales to any generation width while outperforming independent generation, unlocking a more general mode of parallel scaling that effectively leverages information between sequences and is compatible with any post-generation aggregation technique.
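The contrast between independent slices and a holistic tensor can be checked with a toy NumPy experiment. The sketch below is illustrative only. It uses a random ReLU feedforward, projection-free causal self-attention, and projection-free attention along the batch axis; all three are simplifications assumed here, not the paper's layers. It perturbs one generation and tests whether the change reaches another generation.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, D = 4, 5, 8  # parallel generations, tokens, features
H = rng.standard_normal((B, T, D))
W_ff = rng.standard_normal((D, D)) / np.sqrt(D)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def feedforward(h):
    # Mixes along the feature axis only: every (b, t) vector is transformed on its own.
    return np.maximum(h @ W_ff, 0.0)


def causal_self_attention(h):
    # Mixes along the token axis only, separately within each generation b.
    scores = h @ h.transpose(0, 2, 1) / np.sqrt(h.shape[-1])  # (B, T, T)
    causal = np.tril(np.ones((h.shape[1], h.shape[1]), dtype=bool))
    return softmax(np.where(causal, scores, -np.inf)) @ h


def batch_axis_attention(h):
    # Mixes along the batch axis only: at each timestep, the B tokens attend to each other.
    ht = h.transpose(1, 0, 2)  # (T, B, D): each timestep becomes a set of B tokens
    scores = ht @ ht.transpose(0, 2, 1) / np.sqrt(h.shape[-1])  # (T, B, B)
    return (softmax(scores) @ ht).transpose(1, 0, 2)  # back to (B, T, D)


# Perturb generation 0 only and ask which operation carries the change to generation 1.
H_pert = H.copy()
H_pert[0] += 1.0
for name, op in [("feedforward", feedforward),
                 ("causal self-attention", causal_self_attention),
                 ("batch-axis attention", batch_axis_attention)]:
    reaches = not np.allclose(op(H)[1], op(H_pert)[1])
    print(f"{name}: reaches generation 1 -> {reaches}")
```

Only batch-axis attention lets the change in generation 0 reach generation 1. Because it mixes tokens within a single timestep, it also leaves the other timesteps untouched.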

Paper Structure

This paper contains 32 sections, 11 equations, 7 figures, 8 tables, and 1 algorithm.

Figures (7)

  • Figure 1: LLM hidden states are 3-D tensors, where attention and feedforward blocks explicitly transfer information between tokens and features, respectively. By instead treating parallel scaling generations as a single tensor rather than independent slices, our method, Bridge, operates along the batch axis, so that tokens from all sequences that share the same prompt can share information throughout generation.
  • Figure 2: Our method design. (Left) A Bridge block and an input normalization layer are added after each feedforward block. (Right) At each timestep, tokens stemming from the same input prompt attend to each other in Bridge blocks, as denoted by the arrows. Dotted arrows mark every location where information is transferred to a different sequence, in a Markovian fashion (only the token features at the current timestep are shared to predict the next timestep's tokens). Attention is masked for tokens from different prompts and from completed generations. White squares are masked cells.
  • Figure 3: Warm-up procedure. The original LLM generates candidate traces, which are filtered by correctness and compiled into a dataset. SFT on this generated dataset updates only the new parameters. The P-Match baseline replaces Bridge blocks with MLPs matched in parameter count.
  • Figure 4: G-Pass@$8_\tau$ averaged across AIME24, AIME25, AMC23, BRUMO25, CMIMC25, and HMMT_FEB25. Each chart fixes the minimum number of correct answers ($\tau \cdot k$) required out of $k=8$ simultaneous responses. Bridge has the greatest coverage ($\tau \cdot k = 1$) and answers correctly most consistently ($\tau \cdot k > 1$) in the vast majority of cases. Higher is better.
  • Figure 5: Improvement in G-Pass@$8_\tau$ over the original DS-Qwen-7B model, averaged across AIME24, AIME25, AMC23, BRUMO25, CMIMC25, and HMMT_FEB25, as a function of Bridge's evaluation generation width $w$. The x-axis ($\tau \cdot k$) indicates the number of responses out of $k=8$ that must be correct.
  • ...and 2 more figures
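The warm-up data flow in Figure 3 can be sketched in plain Python. This is a minimal sketch under stated assumptions, not details given by the paper. The hypothetical choices are:

- the tuple layout of the candidates;
- the exact-match correctness check;
- grouping by prompt with a minimum count;
- the `bridge` parameter-name marker.

```python
from collections import defaultdict


def build_warmup_dataset(candidates, reference_answers, min_correct_traces=2):
    """Keep only correct traces from the original LLM, grouped by prompt.

    candidates: list of (prompt_id, trace_text, final_answer) tuples.
    reference_answers: dict mapping prompt_id -> ground-truth answer.
    min_correct_traces: hypothetical threshold so that every kept prompt offers
        several sequences for the cross-sample block to learn from.
    """
    correct = defaultdict(list)
    for prompt_id, trace, answer in candidates:
        # Exact string match stands in for a real answer verifier (assumption).
        if answer == reference_answers[prompt_id]:
            correct[prompt_id].append(trace)
    return {pid: traces for pid, traces in correct.items()
            if len(traces) >= min_correct_traces}


def trainable_parameter_names(all_names, new_marker="bridge"):
    # Only the newly added parameters are trained during SFT; the original LLM
    # weights stay frozen. The "bridge" naming convention is hypothetical.
    return [name for name in all_names if new_marker in name]


candidates = [
    (0, "trace a", "42"), (0, "trace b", "41"), (0, "trace c", "42"),
    (1, "trace d", "7"),
]
print(build_warmup_dataset(candidates, {0: "42", 1: "7"}))
# {0: ['trace a', 'trace c']}
names = ["layers.0.mlp.up_proj.weight", "layers.0.bridge.q_proj.weight",
         "layers.0.bridge_norm.weight"]
print(trainable_parameter_names(names))
# ['layers.0.bridge.q_proj.weight', 'layers.0.bridge_norm.weight']
```

In a PyTorch model, the same split would typically be applied by calling `requires_grad_(False)` on every parameter outside this list before building the optimizer.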
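Figures 4 and 5 report G-Pass@$8_\tau$. Assume the usual hypergeometric definition, where a problem has $n$ sampled responses of which $c$ are correct:

$$\text{G-Pass@}k_\tau = \sum_{j=\lceil \tau k \rceil}^{c} \binom{c}{j}\binom{n-c}{k-j} \Big/ \binom{n}{k}$$

This quantity is then averaged over problems. The paper's exact estimator, especially for interdependent response sets, may differ. The Python sketch below takes the threshold $\tau \cdot k$ directly as an integer, matching the x-axis of Figure 5. Setting the threshold to 1 recovers the standard unbiased pass@k estimator, i.e. coverage.

```python
from math import comb


def g_pass_at_k(n, c, k, min_correct):
    """Probability that a random size-k subset of n responses (c of them correct)
    contains at least `min_correct` (= tau * k) correct responses."""
    if not (0 <= c <= n and 1 <= k <= n and 1 <= min_correct <= k):
        raise ValueError("need 0 <= c <= n, 1 <= k <= n, 1 <= min_correct <= k")
    hits = sum(comb(c, j) * comb(n - c, k - j)
               for j in range(min_correct, min(c, k) + 1))
    return hits / comb(n, k)


def mean_g_pass(per_problem_counts, k, min_correct):
    # per_problem_counts: list of (n, c) pairs, one per benchmark problem.
    scores = [g_pass_at_k(n, c, k, min_correct) for n, c in per_problem_counts]
    return sum(scores) / len(scores)


# With exactly one set of k = 8 responses per problem (n = k), the estimator
# reduces to an indicator: did the set contain at least min_correct correct answers?
print(g_pass_at_k(8, 3, 8, 3), g_pass_at_k(8, 3, 8, 4))  # 1.0 0.0
```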