
Learning to Reason Across Parallel Samples for LLM Reasoning

Jianing Qi, Xi Ye, Hao Tang, Zhigang Zhu, Eunsol Choi

TL;DR

This work improves LLM reasoning at test time by aggregating multiple parallel samples from a base model with a compact Sample Set Aggregator (SSA). The SSA is a small LLM trained with reinforcement learning to produce a final answer from a concatenated set of candidate solutions. This decouples answer generation from aggregation and allows the SSA to be used with black-box base models. Across the GSM8K, MATH, AIME, AMC, and Olympiad benchmarks, SSA yields strong gains: a 3B SSA matches or exceeds larger re-ranking baselines and narrows the gap to oracle pass@K, and it generalizes across model scales and families. The authors also introduce a scalable two-stage SSA for large candidate sets and analyze training methods and reasoning token usage, highlighting practical, plug-and-play benefits for test-time reasoning.

Abstract

Scaling test-time compute brings substantial performance gains for large language models (LLMs). By sampling multiple answers and heuristically aggregating them (e.g., through majority voting or by using verifiers to rank the answers), one can achieve consistent performance gains in math domains. In this paper, we propose a new way to leverage such sets of multiple samples. We train a compact LLM, called Sample Set Aggregator (SSA), that takes a concatenated sequence of multiple samples and outputs the final answer, optimizing it for answer accuracy with reinforcement learning. Experiments on five reasoning datasets demonstrate both the efficacy and efficiency of SSA. Notably, SSA improves over naive majority voting by 8% pass@5 on MATH. Furthermore, our 3B SSA surpasses model-based re-ranking with a much larger 72B process reward model. Our analysis also shows promising generalization ability of SSA across sample set sizes, base model families and scales, and tasks. By separating the LLM that generates answers from the LLM that analyzes and aggregates the sampled answers, our approach can work with the outputs of premier black-box models easily and efficiently.
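
To make the ingredients named in the abstract concrete, the following is a minimal sketch, not the authors' implementation. It shows the majority-voting baseline, the oracle check that underlies pass@K, and one way to concatenate K samples into a single aggregator input. The function names, the tie-breaking rule, and the prompt layout ("Question:", "Candidate i:", "Final answer:") are all hypothetical assumptions made for illustration; the paper's actual prompt format is not given in this section.

```python
from collections import Counter
from typing import List, Optional


def majority_vote(answers: List[str]) -> Optional[str]:
    """Baseline aggregation: return the most frequent final answer.

    Ties are broken in favor of the answer that appears first
    (an assumption; the source does not specify a tie-breaking rule).
    """
    if not answers:
        return None
    counts = Counter(answers)
    best = max(counts.values())
    for answer in answers:
        if counts[answer] == best:
            return answer
    return None


def oracle_hit(answers: List[str], gold: str) -> bool:
    """Oracle check behind pass@K: is any of the K sampled answers correct?"""
    return gold in answers


def build_aggregator_input(question: str, samples: List[str]) -> str:
    """Concatenate the question and K sampled solutions into one sequence.

    The layout below is hypothetical; it only illustrates the idea of
    feeding all samples to a separate aggregator model at once.
    """
    parts = [f"Question: {question}"]
    for i, sample in enumerate(samples, start=1):
        parts.append(f"Candidate {i}:\n{sample}")
    parts.append("Final answer:")
    return "\n\n".join(parts)


if __name__ == "__main__":
    sampled_answers = ["12", "15", "12", "9", "15"]
    print(majority_vote(sampled_answers))
    print(oracle_hit(sampled_answers, gold="9"))
    print(build_aggregator_input("What is 3 * 4?", ["3 * 4 = 12", "3 + 4 = 7"]))
```

In this toy case the correct answer "9" appears only once, so majority voting misses it even though the oracle check succeeds. That difference is the gap between majority voting and oracle pass@K that a learned aggregator aims to reduce.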

Paper Structure

This paper contains 41 sections, 5 equations, 17 figures, 13 tables, 1 algorithm.

Figures (17)

  • Figure 1: Illustration of our approach (bottom), parallel method (top left), and sequential method (top right). We train a compact LLM, called Sample Set Aggregator (SSA), to take a concatenated sequence of multiple samples and output the final answer.
  • Figure 2: Performance of the SSA RL, PRM, and majority vote methods across Qwen 2.5 $\text{LLM}_{\text{ans}}$ model sizes (7B, 14B, 32B) and numbers of candidate solutions $k = 5, 10, 15$.
  • Figure 3: Performance comparison between sequential scaling with RL and our SSA.
  • Figure 4: Training method performance and response length analysis. (a) Average accuracy across datasets shows that the RL method generalizes better than the SFT method, with performance improving for larger models. (b) Response length trends during training show a rapid decrease in output length.
  • Figure 5: Performance of methods based on Qwen 2.5 7B with $k=5$. SSAs are shown in green. The SSA method is very effective compared with the baseline methods.
  • ...and 12 more figures