Reject, Resample, Repeat: Understanding Parallel Reasoning in Language Model Inference

Noah Golowich, Fan Chen, Dhruv Rohatgi, Raghav Singhal, Carles Domingo-Enrich, Dylan J. Foster, Akshay Krishnamurthy

TL;DR

The paper introduces a route to rigorously study inference-time methods that aggregate and prune multiple samples, using the lens of particle filtering algorithms such as Sequential Monte Carlo (SMC), and identifies a fundamental limit faced by all particle filtering methods.

Abstract

Inference-time methods that aggregate and prune multiple samples have emerged as a powerful paradigm for steering large language models, yet we lack any principled understanding of their accuracy-cost tradeoffs. In this paper, we introduce a route to rigorously study such approaches using the lens of *particle filtering* algorithms such as Sequential Monte Carlo (SMC). Given a base language model and a *process reward model* estimating expected terminal rewards, we ask: *how accurately can we sample from a target distribution given some number of process reward evaluations?* Theoretically, we identify (1) simple criteria enabling non-asymptotic guarantees for SMC; (2) algorithmic improvements to SMC; and (3) a fundamental limit faced by all particle filtering methods. Empirically, we demonstrate that our theoretical criteria effectively govern the *sampling error* of SMC, though not necessarily its final *accuracy*, suggesting that theoretical perspectives beyond sampling may be necessary.
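
To make the setting in the abstract concrete, the following is a minimal sketch of a generic SMC loop: propose one action per particle from the base model, reweight by the ratio of estimated values, and resample. It is a textbook-style illustration, not the paper's algorithm. The names `smc_sample`, `uniform_bit`, `exact_value`, and `HORIZON` are hypothetical, and the toy task (uniform binary base model, terminal reward $2^{\#\text{ones}}$) is an assumption chosen so that the exact value function has a closed form.

```python
import random

HORIZON = 6  # hypothetical sequence length for the toy task


def uniform_bit(prefix, rng):
    """Toy base model (assumption): each action is a fair coin flip."""
    return rng.randint(0, 1)


def exact_value(prefix):
    """Exact expected terminal reward for the toy reward 2**(number of ones).

    Each future bit contributes E[2**bit] = 1.5 under the uniform base model.
    """
    return 2 ** sum(prefix) * 1.5 ** (HORIZON - len(prefix))


def smc_sample(sample_next, value_hat, horizon, num_particles, rng):
    """Generic SMC sketch; returns one particle drawn uniformly at the end."""
    particles = [() for _ in range(num_particles)]
    for _ in range(horizon):
        # Propose: extend every particle by one action from the base model.
        proposals = [p + (sample_next(p, rng),) for p in particles]
        # Weight: estimated value after the new action over the value before.
        weights = [value_hat(q) / value_hat(p)
                   for p, q in zip(particles, proposals)]
        if sum(weights) <= 0:
            raise RuntimeError("all particles received zero weight")
        # Resample: draw particles with probability proportional to weight.
        particles = rng.choices(proposals, weights=weights, k=num_particles)
    return rng.choice(particles)
```

In this toy task the target distribution (base model tilted by the terminal reward) makes each bit equal to 1 with probability 2/3 independently, so the fraction of ones in the returned sequences should be close to 2/3 when `value_hat` is exact. Each call uses `horizon * num_particles` evaluations of `value_hat` on proposals, which is the cost measure the abstract's question refers to.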

Paper Structure (70 sections, 28 theorems, 132 equations, 8 figures, 4 algorithms)

This paper contains 70 sections, 28 theorems, 132 equations, 8 figures, 4 algorithms.

Key Result

Theorem 1

Suppose that the following hold. (1) Bounded action-level coverage: for all $h$, $a_{1:h+1}$, $\pi^{\star}(a_{h+1} \mid a_{1:h}) / \pi_{\texttt{ref}}(a_{h+1} \mid a_{1:h}) \leq C_{\sf act}$. (2) Bounded $\chi^2$-divergences (these control the error of $\widehat{V}$ vs. $V^\star$, per fact:chisq-av

Figures (8)

  • Figure 1: Performance of SMC (with $N$ particles) vs. Best-of-$N$ on Math500 problems; here we take $N = 32$. Each point is a different problem. SMC with $N$ particles improves performance over Best-of-$N$ on most problems. See \ref{sec:smc-math-experiments} for details.
  • Figure 2: Empirical validation for our theory on the prompt-switching task (see \ref{sec:experiments}). (a) We vary the action-level coverage (as measured by a KL-divergence proxy) across many prompts while keeping $\widehat{V} = V^\star$ fixed, and observe that action-level coverage predicts the sampling error of SMC. (b) We fix $\pi_{\texttt{ref}} = \pi^{\star}$ (so that action-level coverage is fixed) and observe that $D_{\mathsf{KL}}(\pi^{\star}_h \,\|\, \widehat{\pi}_h)$ predicts the sampling error of SMC.
  • Figure 3: Performance of SMC on the prompt-switching task (\ref{sec:experiments-particles}). Each point represents SMC with some number of particles $N \in \{2,4,8,\ldots,128\}$. We average the sampling error and wall-clock time over all data points and all trials per data point. SMC consistently outperforms both sequential importance sampling (SIS) and Best-of-$N$ (BoN) baselines.
  • Figure 4: Influence of PRM error (measured by $D_{\chi^2}(\pi^{\star}_h \,\|\, \widehat{\pi}_h)$) on SMC accuracy for 10 Math500 problems. Different colors correspond to different values of the inverse temperature parameter $\lambda$ parametrizing $\widehat{V}^{(\lambda)}$. The plots show SMC accuracy with respect to a random particle (not the best one) selected at the end of SMC.
  • Figure 5: Performance of SMC vs. Best-of-$N$ on AIME problems (each point is a different problem). Similar to \ref{fig:smc-vs-bon}, the majority of points lie below the line $y=x$, indicating that SMC improves performance over Best-of-$N$ on most problems.
  • ...and 3 more figures
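
The captions above measure error through $D_{\mathsf{KL}}$ and $D_{\chi^2}$ between $\pi^{\star}_h$ and $\widehat{\pi}_h$, and Theorem 1 bounds a per-action density ratio. As a rough illustration, the sketch below computes these three quantities for discrete distributions given as probability lists. It assumes the standard textbook definitions ($D_{\chi^2}(P \,\|\, Q) = \sum_x P(x)^2 / Q(x) - 1$), which may differ from the paper's normalization; the function names are hypothetical.

```python
import math


def kl_divergence(p, q):
    """D_KL(p || q) for discrete distributions; inf if q lacks support."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total


def chi2_divergence(p, q):
    """Pearson chi-squared divergence: sum of p^2 / q, minus 1."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf
            total += pi * pi / qi
    return total - 1.0


def max_density_ratio(p, q):
    """Largest ratio p(x) / q(x), the kind of quantity a coverage bound caps."""
    ratios = [math.inf if qi == 0 else pi / qi
              for pi, qi in zip(p, q) if pi > 0]
    return max(ratios)
```

For example, with `p = [2/3, 1/3]` and `q = [0.5, 0.5]`, the maximum density ratio is 4/3 and the chi-squared divergence is 1/9. The KL divergence is always at most $\log(1 + D_{\chi^2})$, so a bounded $\chi^2$-divergence is the stronger of the two conditions.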

Theorems & Definitions (39)

  • Theorem 1: Corollary of \ref{thm:smc-chisq}
  • Remark 2: Connection to literature on refinements of SMC
  • Remark 3: Special case: autoregressive generation
  • Theorem 5
  • Definition 7
  • Theorem 8
  • Theorem 9
  • Theorem 11
  • Theorem 12
  • Proposition 15
  • ...and 29 more