Effective Reinforcement Learning for Reasoning in Language Models

Lianghuan Huang, Shuo Li, Sagnik Anupam, Insup Lee, Osbert Bastani

TL;DR

This paper investigates how to design reinforcement learning algorithms that enhance reasoning in language models, with an emphasis on small to mid-sized models trained under computational constraints. It compares supervised fine-tuning with on-policy RL, analyzes PPO/GRPO training dynamics, and shows that removing KL regularization can yield more concise generations and higher accuracy. The authors introduce DASH, an algorithm combining preemptive sampling and gradient filtering that reduces on-policy training time by up to 83% without sacrificing performance. Across math and coding tasks, DASH delivers robust improvements, and the paper offers practical guidance for tailoring RL methods to LM reasoning, highlighting how design choices about training strategy, update rules, advantage estimation, and sampling affect both effectiveness and efficiency.
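To make the PPO/GRPO and KL-regularization discussion concrete, here is a minimal sketch of a per-token PPO-style clipped loss with an optional KL penalty toward a reference model. It is not the paper's implementation: the function name, the `clip_eps=0.2` default, and the choice of the non-negative per-token KL estimator exp(log r) - log r - 1 (the form used in GRPO) are assumptions. Setting `beta=0.0` corresponds to removing KL regularization.

```python
import math
from typing import Sequence


def clipped_token_loss(logp_new: Sequence[float], logp_old: Sequence[float],
                       logp_ref: Sequence[float], advantages: Sequence[float],
                       clip_eps: float = 0.2, beta: float = 0.0) -> float:
    """Mean per-token PPO-style clipped loss with an optional KL penalty (to minimize).

    logp_new: log-probs of the sampled tokens under the policy being updated.
    logp_old: log-probs under the policy that generated the samples.
    logp_ref: log-probs under a frozen reference model (ignored when beta == 0).
    beta = 0.0 removes the KL regularizer entirely (assumed default, for illustration).
    """
    total = 0.0
    for lp_new, lp_old, lp_ref, adv in zip(logp_new, logp_old, logp_ref, advantages):
        ratio = math.exp(lp_new - lp_old)                    # importance ratio pi_new / pi_old
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        surrogate = min(ratio * adv, clipped * adv)          # pessimistic (clipped) objective
        log_r = lp_ref - lp_new
        kl = math.exp(log_r) - log_r - 1.0                   # non-negative per-token KL estimate
        total += -surrogate + beta * kl
    return total / len(logp_new)
```

When `logp_old` equals `logp_new` (the first update on freshly sampled data), the ratio is 1 and the gradient of this loss matches a vanilla policy-gradient step; the clipping only matters for the additional off-policy PPO updates on the same batch.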

Abstract

Reinforcement learning (RL) has emerged as a promising strategy for improving the reasoning capabilities of language models (LMs) in domains such as mathematics and coding. However, most modern RL algorithms were designed for robotics applications, which differ significantly from LM reasoning. We analyze RL algorithm design decisions for LM reasoning with respect to both accuracy and computational efficiency, focusing on relatively small models due to computational constraints. Our findings are: (i) on-policy RL significantly outperforms supervised fine-tuning (SFT), (ii) PPO-based off-policy updates increase accuracy rather than reducing variance, and (iii) removing the KL divergence penalty can lead to more concise generations and higher accuracy. Furthermore, we find that a key bottleneck to computational efficiency is that the optimal batch sizes for inference and backpropagation differ. We propose a novel algorithm, DASH, that performs preemptive sampling (i.e., sampling a large batch and accumulating gradient updates in small increments) and gradient filtering (i.e., dropping samples with small advantage estimates). We show that DASH reduces training time by 83% compared to a standard implementation of GRPO without sacrificing accuracy. Our findings provide insights into designing effective RL algorithms for LM reasoning.
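The abstract describes gradient filtering as dropping samples with small advantage estimates. The sketch below assumes GRPO-style group-normalized advantages (reward minus the group mean, divided by the group standard deviation); the names `group_advantages` and `gradient_filter`, the use of the population standard deviation, and the `threshold` value are hypothetical choices for illustration, not taken from the paper.

```python
from statistics import mean, pstdev
from typing import List, Sequence, Tuple


def group_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantage: reward minus group mean, scaled by group std (assumed form)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def gradient_filter(groups: Sequence[Tuple[Sequence[str], Sequence[float]]],
                    threshold: float = 0.1) -> List[Tuple[str, float]]:
    """Keep only (response, advantage) pairs with |advantage| >= threshold.

    groups: one (responses, rewards) pair per prompt.
    threshold: hypothetical cutoff; the paper's actual criterion may differ.
    """
    kept: List[Tuple[str, float]] = []
    for responses, rewards in groups:
        for response, adv in zip(responses, group_advantages(rewards)):
            if abs(adv) >= threshold:
                kept.append((response, adv))  # only these samples are backpropagated
    return kept


kept = gradient_filter([(["a", "b", "c", "d"], [1.0, 0.0, 1.0, 0.0]),
                        (["e", "f"], [1.0, 1.0])])
```

In this example the second group gets the same reward for every response, so all of its advantages are zero and the whole group is skipped. Without a KL term, the policy-gradient contribution of a sample scales with its advantage, so such samples would cost a backward pass while contributing little or nothing to the update.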

Paper Structure

This paper contains 23 sections, 7 equations, 9 figures, 10 tables.

Figures (9)

  • Figure 1: DASH can reduce running time by 83% compared to GRPO by using preemptive sampling and gradient filtering (each described in its own section).
  • Figure 2: Illustration of preemptive sampling. We use $H$ GPUs for inference and $H'$ GPUs for backpropagation, shown in blue and green, respectively. Given a batch of $M$ prompts $\{\mathbf{x}_1, \ldots, \mathbf{x}_M\}$, the inference GPUs generate the corresponding responses $\{\hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_M\}$, which are aggregated across GPUs into CPU memory. When a backpropagation GPU requests generations for a prompt $\mathbf{x}_m$, the corresponding cached response $\hat{\mathbf{y}}_m$ is retrieved and delivered. Since groups are used for advantage estimation, each prompt $\mathbf{x}_m$ is duplicated to form a group, and all generations in the same group are sent to the backpropagation GPU upon request.
  • Figure 3: Comparison between DASH and No-GF (DASH without gradient filtering) for Qwen2.5-0.5B on math.
  • Figure 4: Training reward curves for PG vs. PPO on Qwen2.5-0.5B for math.
  • Figure 5: Impact of KL divergence regularization on pass@k for Qwen2.5-3B on math.
  • ...and 4 more figures
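The Figure 2 caption describes preemptive sampling: one large inference batch is generated and cached in CPU memory, and backpropagation workers fetch whole groups on request and process them in small increments. The single-process sketch below illustrates only that decoupling of batch sizes. `ResponseCache`, `preemptive_sample`, `accumulate_gradients`, and the `generate` and `grad_fn` callables are hypothetical stand-ins for an inference engine and a backward pass; the actual system distributes this work across $H$ inference GPUs and $H'$ backpropagation GPUs.

```python
from typing import Callable, Dict, List, Sequence

Group = List[str]


class ResponseCache:
    """CPU-side store of pre-generated response groups, keyed by prompt index m."""

    def __init__(self) -> None:
        self._groups: Dict[int, Group] = {}

    def put(self, m: int, group: Group) -> None:
        self._groups[m] = group

    def request(self, m: int) -> Group:
        # A backprop worker asks for prompt m and receives the whole group at once.
        return self._groups.pop(m)

    def __len__(self) -> int:
        return len(self._groups)


def preemptive_sample(prompts: Sequence[str], generate: Callable[[List[str]], List[str]],
                      group_size: int, cache: ResponseCache) -> None:
    # Inference phase: duplicate every prompt group_size times and generate in one large batch.
    flat = [x for x in prompts for _ in range(group_size)]
    outputs = generate(flat)
    for m in range(len(prompts)):
        cache.put(m, outputs[m * group_size:(m + 1) * group_size])


def accumulate_gradients(num_prompts: int, cache: ResponseCache, micro_batch: int,
                         grad_fn: Callable[[List[Group]], List[float]]) -> List[float]:
    # Backprop phase: pull a few cached groups at a time; grad_fn (hypothetical) returns the
    # summed gradient over one small micro-batch, and the results are accumulated.
    total: List[float] = []
    for start in range(0, num_prompts, micro_batch):
        groups = [cache.request(m) for m in range(start, min(start + micro_batch, num_prompts))]
        g = grad_fn(groups)
        total = list(g) if not total else [a + b for a, b in zip(total, g)]
    return [a / num_prompts for a in total]  # mean gradient for a single optimizer step


cache = ResponseCache()
preemptive_sample(["p0", "p1", "p2"], lambda xs: [x + "!" for x in xs], 2, cache)
grad = accumulate_gradients(
    3, cache, 2, lambda gs: [float(len(gs)), float(sum(len(r) for g in gs for r in g))])
```

In a real trainer, `micro_batch` would be sized to fit backpropagation memory, independently of the much larger batch that keeps inference efficient; that mismatch is the bottleneck the abstract identifies.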