Accelerating RL for LLM Reasoning with Optimal Advantage Regression

Kianté Brantley, Mingyu Chen, Zhaolin Gao, Jason D. Lee, Wen Sun, Wenhao Zhan, Xuezhou Zhang

TL;DR

The paper introduces $A^\star$-PO, a two-stage KL-regularized RL framework for efficiently fine-tuning LLMs on reasoning tasks. Stage one estimates the optimal value function $V^\star$ offline from samples drawn from a reference policy. Stage two performs on-policy least-squares updates with a single generation per prompt, which removes the need for a critic network and for clipping. The theory shows near-optimal performance with polynomial sample complexity, and experiments on GSM8K, MATH, and long-context benchmarks show faster training, lower memory usage, and competitive reasoning accuracy. The result is a simpler, more scalable alternative to PPO, GRPO, and REBEL for improving LLM reasoning.
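
For orientation, the standard closed-form solution of a KL-regularized objective is sketched below. It is a well-known identity, not a quotation from the paper. The symbols $\beta$ (regularization strength), $r$ (reward), $x$ (prompt), and $y$ (response) are notation assumed here, and the paper may use separate coefficients for its two stages.

```latex
\begin{align*}
\pi^\star &= \arg\max_{\pi}\; \mathbb{E}_{x,\, y\sim\pi(\cdot\mid x)}\big[r(x,y)\big]
  - \beta\,\mathbb{E}_{x}\,\mathrm{KL}\big(\pi(\cdot\mid x)\,\|\,\pi_{\mathsf{ref}}(\cdot\mid x)\big), \\
\pi^\star(y\mid x) &= \pi_{\mathsf{ref}}(y\mid x)\,\exp\!\big((r(x,y)-V^\star(x))/\beta\big), \\
V^\star(x) &= \beta\,\ln \mathbb{E}_{y\sim\pi_{\mathsf{ref}}(\cdot\mid x)}\big[\exp(r(x,y)/\beta)\big], \\
A^\star(x,y) &= r(x,y) - V^\star(x) = \beta\,\ln\frac{\pi^\star(y\mid x)}{\pi_{\mathsf{ref}}(y\mid x)}.
\end{align*}
```

Under these identities, a squared loss that regresses $\beta\ln\big(\pi(y\mid x)/\pi_{\mathsf{ref}}(y\mid x)\big)$ onto $r(x,y)-\hat{V}^\star(x)$ is minimized at $\pi=\pi^\star$ when the estimate $\hat{V}^\star$ is exact, which is why no critic for the current policy is required.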

Abstract

Reinforcement learning (RL) has emerged as a powerful tool for fine-tuning large language models (LLMs) to improve complex reasoning abilities. However, state-of-the-art policy optimization methods often suffer from high computational overhead and memory consumption, primarily due to the need for multiple generations per prompt and the reliance on critic networks or advantage estimates of the current policy. In this paper, we propose $A^\star$-PO, a novel two-stage policy optimization framework that directly approximates the optimal advantage function and enables efficient training of LLMs for reasoning tasks. In the first stage, we leverage offline sampling from a reference policy to estimate the optimal value function $V^\star$, eliminating the need for costly online value estimation. In the second stage, we perform on-policy updates using a simple least-squares regression loss with only a single generation per prompt. Theoretically, we establish performance guarantees and prove that the KL-regularized RL objective can be optimized without requiring complex exploration strategies. Empirically, $A^\star$-PO achieves competitive performance across a wide range of mathematical reasoning benchmarks, while reducing training time by up to 2$\times$ and peak memory usage by over 30% compared to PPO, GRPO, and REBEL. Implementation of $A^\star$-PO can be found at https://github.com/ZhaolinGao/A-PO.
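
The two stages described in the abstract can be illustrated with a minimal sketch. It assumes the standard log-mean-exp form of the optimal value under KL regularization and scalar sequence-level log-probabilities. The function names `estimate_v_star` and `a_star_po_loss` and the single coefficient `beta` are hypothetical and are not taken from the authors' repository.

```python
import math


def estimate_v_star(rewards, beta):
    """Stage 1 (assumed form): estimate V*(x) from the rewards of N
    responses sampled offline from the reference policy, using
    beta * log(mean(exp(r / beta)))."""
    m = max(rewards)
    # Subtract the max before exponentiating for numerical stability.
    mean_exp = sum(math.exp((r - m) / beta) for r in rewards) / len(rewards)
    return m + beta * math.log(mean_exp)


def a_star_po_loss(logp_policy, logp_ref, reward, v_star, beta):
    """Stage 2 (assumed form): squared error between the implicit advantage
    beta * log(pi / pi_ref) and the target advantage r - V*, computed for
    one on-policy response to one prompt."""
    implicit_adv = beta * (logp_policy - logp_ref)
    target_adv = reward - v_star
    return (implicit_adv - target_adv) ** 2


# Hypothetical usage: four reference samples with binary rewards.
v_hat = estimate_v_star([1.0, 0.0, 0.0, 1.0], beta=0.5)
loss = a_star_po_loss(logp_policy=-12.0, logp_ref=-12.5,
                      reward=1.0, v_star=v_hat, beta=0.5)
```

In a real training loop the loss would be averaged over a batch and differentiated with respect to the policy's log-probabilities; this sketch only shows the scalar arithmetic.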

Paper Structure

This paper contains 43 sections, 11 theorems, 89 equations, 11 figures, 4 tables, 1 algorithm.

Key Result

Theorem 1

Suppose that Assumptions ass:r-real, ass:l-bound, and ass:vref hold. With probability at least $1-\delta$, we have

Figures (11)

  • Figure 1: We present $A^\star$-PO, an efficient, regression-based approach for LLM post-training. Prior methods such as GRPO and PPO incur high computational costs, either due to requiring multiple samples per prompt or maintaining an explicit value network. In contrast, $A^\star$-PO simplifies the training process by estimating the optimal value function using offline generations from $\pi_{\mathsf{ref}}$ and requiring only a single response per prompt during online RL. As a result, $A^\star$-PO reduces training time by up to 2$\times$ compared to GRPO and PPO.
  • Figure 2: Test accuracy versus training time, peak memory usage, and KL divergence across four baselines and three model sizes on GSM8K. Our approach (orange) achieves accuracy comparable to the GRPO and PPO baselines while being 2$\times$ faster, using less memory, and incurring a smaller KL divergence. Note that for $A^\star$-PO, the training time includes the time from both stages (i.e., offline data collection from $\pi_{\mathsf{ref}}$ and online RL training).
  • Figure 3: Ablation results with different numbers of samples $N$ for estimating $V^{\star}$. Solid lines indicate the moving average with window size 100. (Left) Squared regression loss per step of $A^\star$-PO. (Middle) Training reward per step. (Right) Model performance on MATH500 with varying values of $N$.
  • Figure 4: Ablation Results on Filtering Hard Prompts. (Green) Training time of Stage 2 for different values of $N$. The dashed line indicates the training time without filtering. (Purple) Performance on MATH500 with and without filtering.
  • Figure 5: Ablation results with different $\beta_1$ for estimating $V^{\star}$. Solid lines indicate the moving average with window size 100. (Left) Squared regression loss per step of $A^\star$-PO. (Middle) Training reward per step. (Right) Model performance on MATH500 with varying values of $\beta_1$.
  • ...and 6 more figures
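
The hard-prompt filtering ablated in Figure 4 might look like the sketch below. The criterion used here is an assumption that the caption does not confirm: rewards are taken to be binary, and a prompt counts as hard when none of its $N$ reference samples earns a positive reward. The function name `filter_hard_prompts` and the example data are hypothetical.

```python
def filter_hard_prompts(rewards_by_prompt):
    """Keep only prompts for which at least one of the N reference samples
    earned a positive reward (assumed definition of a non-hard prompt)."""
    return {
        prompt: rewards
        for prompt, rewards in rewards_by_prompt.items()
        if any(r > 0 for r in rewards)
    }


# Hypothetical offline data: N = 4 reference samples per prompt.
offline_rewards = {
    "prompt_a": [0.0, 1.0, 0.0, 0.0],
    "prompt_b": [0.0, 0.0, 0.0, 0.0],  # every sample failed: filtered out
    "prompt_c": [1.0, 1.0, 1.0, 0.0],
}
kept = filter_hard_prompts(offline_rewards)
```

Under this assumed criterion, a larger $N$ gives each prompt more chances to produce a correct sample, so fewer prompts are dropped and Stage 2 has more data to train on, which would be consistent with the Stage 2 training time varying with $N$.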

Theorems & Definitions (14)

  • Definition 1
  • Theorem 1
  • Corollary 1: log-linear policies with FTPL, Informal
  • Theorem 2: log-linear policies with OGD, Informal
  • Lemma 1
  • Lemma 2: Suggala and Netrapalli, 2020 (pmlr-v117-suggala20a)
  • Corollary 2: log-linear policies with FTPL
  • proof
  • Lemma 3
  • Theorem 3: log-linear policies with OGD
  • ...and 4 more