
Not All Rollouts are Useful: Down-Sampling Rollouts in LLM Reinforcement Learning

Yixuan Even Xu, Yash Savani, Fei Fang, J. Zico Kolter

TL;DR

This paper identifies a fundamental bottleneck in RLVR for LLMs: rollout generation scales well, but policy updates are memory- and communication-bound. It proposes PODS, which generates $n$ rollouts per prompt but trains on only a selected subset of size $m$, chosen by a max-variance down-sampling criterion, achieving significant speedups without sacrificing performance. The method is validated across multiple model sizes (3B–7B), architectures (Qwen2.5, Llama3.2), and hardware configurations on GSM8K and MATH. GRPO-PODS reaches the peak test accuracy of vanilla GRPO at least 1.7× faster, often with better final accuracy. The work offers a practical, algorithm-agnostic approach to accelerating RLVR training and opens avenues for integrating PODS with other RL methods and adaptive sampling strategies.

Abstract

Reinforcement learning with verifiable rewards (RLVR) has emerged as the leading approach for enhancing reasoning capabilities in large language models. However, it faces a fundamental compute and memory asymmetry: rollout generation is embarrassingly parallel and memory-light, whereas policy updates are communication-heavy and memory-intensive. To address this, we introduce PODS (Policy Optimization with Down-Sampling), which decouples rollout generation from policy updates by training only on a strategically selected subset of rollouts, maintaining learning quality while dramatically reducing update costs. We propose a principled subset selection criterion, max-variance down-sampling, that maximizes reward diversity, and provide an efficient $O(n\log n)$ implementation. Empirically, Group Relative Policy Optimization (GRPO) with PODS achieves the peak test accuracy of vanilla GRPO at least $\mathbf{1.7\times}$ faster across the different reasoning benchmarks and hardware configurations we tested.

Paper Structure

This paper contains 29 sections, 3 theorems, 4 equations, 8 figures, 3 tables, and 2 algorithms.

Key Result

Lemma 3.1

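For a sorted list of rewards $r_1 \leq r_2 \leq \cdots \leq r_n$, the variance-maximizing subset of size $m$ always consists of the $k$ highest rewards and $(m-k)$ lowest rewards for some $k \in \{0,1,\ldots,m\}$. That is,

$$\max_{k \in \{0,1,\ldots,m\}} \mathrm{Var}\big(\{r_1,\ldots,r_{m-k}\} \cup \{r_{n-k+1},\ldots,r_n\}\big) = \max_{S \subseteq \{1,\ldots,n\},\ |S| = m} \mathrm{Var}\big(\{r_i : i \in S\}\big).$$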
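Lemma 3.1 suggests a simple selection procedure: sort the rewards once, then check each of the $m+1$ candidate splits between "lowest" and "highest" rewards. The sketch below is one way this could be done, not the authors' implementation. It assumes population variance, uses prefix sums so that each candidate costs $O(1)$ after the $O(n\log n)$ sort, and the function name `max_variance_down_sample` is hypothetical.

```python
from typing import List, Sequence


def max_variance_down_sample(rewards: Sequence[float], m: int) -> List[int]:
    """Return indices of m rewards whose (population) variance is maximal.

    Illustrative sketch only: relies on the lemma that an optimal subset
    is the (m - k) lowest plus the k highest rewards for some k.
    """
    n = len(rewards)
    if not 1 <= m <= n:
        raise ValueError("m must satisfy 1 <= m <= len(rewards)")

    # Indices sorted by reward, ascending: the O(n log n) step.
    order = sorted(range(n), key=lambda i: rewards[i])

    # Prefix sums of r and r^2 over the sorted rewards.
    s1 = [0.0]
    s2 = [0.0]
    for i in order:
        r = float(rewards[i])
        s1.append(s1[-1] + r)
        s2.append(s2[-1] + r * r)

    best_k = 0
    best_var = float("-inf")
    for k in range(m + 1):
        low = m - k  # number of lowest rewards kept
        total = s1[low] + (s1[n] - s1[n - k])
        total_sq = s2[low] + (s2[n] - s2[n - k])
        var = total_sq / m - (total / m) ** 2  # E[r^2] - (E[r])^2
        if var > best_var:
            best_var = var
            best_k = k

    low = m - best_k
    return order[:low] + order[n - best_k:]


rewards = [0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.2, 0.9]
selected = max_variance_down_sample(rewards, m=4)
print(sorted(rewards[i] for i in selected))
```

In this toy example the selected subset contains the two lowest and the two highest rewards, which is the most spread-out choice of four values from the list.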

Figures (8)

  • Figure 1: Inference scales efficiently while policy updates become memory-bound in RLVR. Empirical timing breakdown when fine-tuning Qwen2.5-3B-Instruct on GSM8K using $8$ A100-80GB GPUs with varying rollouts per GPU. Top: Total wall-clock time per iteration. Policy updates hit memory limits after $32$ rollouts per GPU (OOM beyond this point), requiring gradient accumulation that dramatically slows training. Bottom: Per-token inference time decreases $21\times$ through batching (from $8$ to $512$ rollouts), saturating beyond $512$. This demonstrates the core asymmetry that PODS exploits: inference parallelizes efficiently while policy updates become memory-bound.
  • Figure 2: Visualization of three training strategies: vanilla GRPO, GRPO with gradient accumulation (GRPO-GA), and GRPO with PODS (GRPO-PODS). Vanilla GRPO generates $n$ rollouts and trains on all of them, leaving inference hardware underutilized. GRPO-GA alleviates this issue with memory-saving techniques such as gradient accumulation, but at the cost of more sequential steps in the policy-update phase. In contrast, GRPO-PODS also generates $n$ rollouts but trains on only $m$ carefully selected examples, maximizing inference utilization, avoiding gradient-accumulation overhead, and providing a cleaner learning signal that yields better final performance.
  • Figure 3: Performance and per-step run time comparison of standard GRPO and GRPO-PODS with max-variance down-sampling across different datasets and hardware environments. For the performance comparison, the x-axis shows the training time, and the y-axis shows the accuracy on the test set. The shaded area represents 1.96 times the standard error of the mean.
  • Figure 4: Performance and per-step run time comparison of GRPO-PODS with max-variance down-sampling across different settings of $n$ and $m$. The training is conducted on the GSM8K dataset with one L40S. For the performance comparison, the x-axis shows the training time, and the y-axis shows the accuracy on the test set. The shaded area represents 1.96 times the standard error of the mean.
  • Figure 5: Performance and per-step run time comparison of GRPO-PODS with the max-variance, max-reward and random down-sampling rules. The training is conducted on the GSM8K dataset with one L40S. For the performance comparison, the x-axis shows the training time, and the y-axis shows the accuracy on the test set. The shaded area represents 1.96 times the standard error of the mean.
  • ...and 3 more figures
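
The generate-then-select pattern described in the Figure 2 caption, together with the random and max-reward rules named in the Figure 5 caption, can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's code: the names `random_rule`, `max_reward_rule`, and `pods_select` are hypothetical, a down-sampling rule is modeled as a function from rewards and $m$ to kept indices, and the group-normalized advantages are assumed to be computed over the selected subset only.

```python
import random
import statistics
from typing import Callable, List, Sequence, Tuple

# Assumed interface: a rule maps (rewards, m) to the indices of rollouts to keep.
DownSampleRule = Callable[[Sequence[float], int], List[int]]


def random_rule(rewards: Sequence[float], m: int) -> List[int]:
    """Keep m rollouts chosen uniformly at random."""
    return sorted(random.sample(range(len(rewards)), m))


def max_reward_rule(rewards: Sequence[float], m: int) -> List[int]:
    """Keep the m rollouts with the highest rewards."""
    by_reward = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)
    return by_reward[:m]


def pods_select(
    rewards: Sequence[float], m: int, rule: DownSampleRule, eps: float = 1e-8
) -> Tuple[List[int], List[float]]:
    """Down-sample n rollouts to m and normalize rewards within the kept group.

    Assumption: advantages are (r - mean) / (std + eps) over the kept rollouts.
    """
    idx = rule(rewards, m)
    if len(idx) != m or len(set(idx)) != m:
        raise ValueError("rule must return m distinct indices")
    kept = [float(rewards[i]) for i in idx]
    mean = statistics.fmean(kept)
    std = statistics.pstdev(kept)
    advantages = [(r - mean) / (std + eps) for r in kept]
    return idx, advantages


# n = 8 rollouts with binary (verifiable) rewards, train on m = 4 of them.
rewards = [0, 1, 0, 1, 1, 0, 0, 1]
random.seed(0)
print(pods_select(rewards, 4, random_rule))
print(pods_select(rewards, 4, max_reward_rule))
```

In this particular toy input, the max-reward rule keeps four rollouts with identical reward, so every normalized advantage is zero. Only the kept indices would be passed on to the policy-update phase.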

Theorems & Definitions (4)

  • Definition 3.1: Down-sampling rule
  • Lemma 3.1
  • Theorem 1
  • Theorem 2