RL-finetuning LLMs from on- and off-policy data with a single algorithm

Yunhao Tang, Taco Cohen, David W. Zhang, Michal Valko, Rémi Munos

TL;DR

This work introduces Any-Generation Reward Optimization (AGRO), a unified RLHF-fine-tuning algorithm for LLMs that leverages generation consistency to enable learning from both on-policy and off-policy data. It derives variance-based loss functions from the consistency condition and provides gradient decompositions that include pathwise and likelihood-ratio components, ensuring convergence to the optimal policy $\pi^*$. The authors propose off-policy and on-policy AGRO variants, with token-level implementations and variance-reduction techniques, and demonstrate competitive gains on a mathematics reasoning benchmark (MATH) using an 8B Llama-3 model. They also compare against KL-regularized policy gradient, showing AGRO's superior convergence properties and KL-efficiency in off-policy settings, while discussing limitations and future work on stability and importance sampling for broader applicability.

Abstract

We introduce a novel reinforcement learning algorithm (AGRO, for Any-Generation Reward Optimization) for fine-tuning large language models. AGRO leverages the concept of generation consistency, which states that the optimal policy satisfies a consistency condition across any possible generation of the model. We derive algorithms that find optimal solutions via the sample-based policy gradient and provide theoretical guarantees on their convergence. Our experiments demonstrate the effectiveness of AGRO in both on-policy and off-policy settings, showing improved performance on a mathematical reasoning dataset over baseline algorithms.
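
To make the idea of a consistency-derived, variance-based loss concrete, here is a minimal NumPy sketch. It is not the authors' implementation: it assumes the per-generation quantity is `r - beta * (log pi - log pi_ref)` (the standard form for a KL-regularized objective), that log-probabilities are given at sequence level, and that the loss is half the per-prompt sample variance of that quantity. The function name `consistency_variance_loss` and the toy numbers are hypothetical.

```python
import numpy as np

def consistency_variance_loss(rewards, logp, logp_ref, beta):
    """Half the per-prompt variance of q = r - beta * (log pi - log pi_ref).

    All inputs have shape (num_prompts, num_generations). logp and logp_ref
    are sequence-level log-probabilities, e.g. sums of token log-probabilities.
    The 0.5 factor is a convention assumed here, not taken from the source.
    """
    rewards = np.asarray(rewards, dtype=float)
    log_ratio = np.asarray(logp, dtype=float) - np.asarray(logp_ref, dtype=float)
    q = rewards - beta * log_ratio
    # Consistency says q should be constant across generations of one prompt,
    # so penalize its spread around the per-prompt mean.
    centered = q - q.mean(axis=1, keepdims=True)
    return 0.5 * np.mean(centered ** 2)

# Toy check on one prompt with three possible generations (assumed values).
beta = 0.5
rewards = np.array([[1.0, 0.0, 2.0]])
logp_ref = np.log(np.array([[0.5, 0.3, 0.2]]))

# Closed-form optimum of the assumed KL-regularized problem.
unnorm = np.exp(logp_ref + rewards / beta)
logp_star = np.log(unnorm / unnorm.sum(axis=1, keepdims=True))

loss_at_ref = consistency_variance_loss(rewards, logp_ref, logp_ref, beta)
loss_at_opt = consistency_variance_loss(rewards, logp_star, logp_ref, beta)
```

In this toy, the loss at the reference policy reduces to half the variance of the rewards, while at the closed-form optimum it vanishes up to floating-point error. Because the loss only asks for agreement across whichever generations are supplied, the generations need not come from the current policy.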

Paper Structure

This paper contains 41 sections, 5 theorems, 45 equations, 6 figures.

Key Result

Theorem 1

(Generation consistency at optimality) Let $\pi^* \stackrel{\small{\mathsf{def}}}{=} \arg\max_\pi {\cal G}(\pi)$ be the optimal policy for the RLHF problem defined by the regularized objective ${\cal G}$. Then for any generation $y$, the following quantity, computed per $(x,y)$, does not depend on $y$; it is a function of the prompt $x$ only.
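
As a hedged reconstruction of the quantity the theorem refers to: assuming the regularized objective ${\cal G}$ is the standard KL-regularized RLHF objective with reward $r$, coefficient $\beta$ and reference policy $\pi_{\text{ref}}$ (the symbols $r$ and $Z$ are assumptions introduced here), the property follows from the closed-form optimum by taking logarithms:

```latex
\begin{align*}
{\cal G}(\pi) &= \mathbb{E}_{x}\Big[\mathbb{E}_{y\sim\pi(\cdot|x)}\big[r(x,y)\big]
  - \beta\,\mathbb{KL}\big(\pi(\cdot|x),\pi_{\text{ref}}(\cdot|x)\big)\Big], \\
\pi^*(y|x) &= \frac{\pi_{\text{ref}}(y|x)\exp\big(r(x,y)/\beta\big)}{Z(x)},
  \qquad Z(x) = \sum_{y'}\pi_{\text{ref}}(y'|x)\exp\big(r(x,y')/\beta\big), \\
r(x,y) - \beta\log\frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} &= \beta\log Z(x).
\end{align*}
```

Under these assumptions the left-hand side of the last line is the per-$(x,y)$ quantity, and the right-hand side depends on the prompt $x$ alone.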

Figures (6)

  • Figure 1: KL-regularized policy gradient vs. AGRO in the tabular case with off-policy data (data generated from $\pi_{\text{ref}}$ rather than $\pi$). We measure the normalized KL-divergence $\mathbb{KL}(\pi,\pi^*)$ between $\pi$ and the optimal policy $\pi^*$ during training. We see that under off-policy data, KL-regularized policy gradient does not converge to the optimal policy $\pi^*$, while AGRO converges, as evidenced by the vanishing KL divergence.
  • Figure 2: Training performance for on-policy learning. We compare three algorithmic alternatives: regularized policy gradient algorithm (red), off-policy AGRO algorithm with on-policy sampling (blue) and on-policy AGRO algorithm (green), all with regularization $\beta=0.001$. The performance is evaluated against the learning iterations, with each iteration being $100$ gradient updates. We observe a similar performance from the regularized policy gradient algorithm and off-policy AGRO, which is expected; on-policy AGRO appears to achieve better performance than the other two algorithms given the same number of learning steps.
  • Figure 3: Evaluation on the test set for the training experiments conducted in Figure 2. We plot the evaluation performance as a function of the learning iterations. We see that the training and evaluation performance are quite correlated, and that on-policy AGRO seems to achieve the best performance across different algorithmic variants.
  • Figure 4: Comparison of on-policy vs. off-policy algorithms. Within the same number of updates, on-policy algorithms generally deviate more from the reference policy $\pi_{\text{ref}}$, leading to a larger increase in training performance. However, the KL-performance curve traced out by the off-policy algorithms aligns with that of the on-policy algorithms, indicating that we do not suffer any KL inefficiency despite being off-policy.
  • Figure 5: Evaluation performance for off-policy (offline) learning. We compare two algorithms: KL-regularized policy gradient (red) and off-policy AGRO (blue). We see that as training progresses, AGRO seems to obtain better overall performance than KL-regularized policy gradient.
  • ...and 1 more figure
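
The tabular off-policy setting of Figure 1 can be illustrated with a small NumPy sketch. This is not a reproduction of the paper's experiment: it assumes a softmax policy over four outcomes, uses exact expectations under $\pi_{\text{ref}}$ instead of samples, and minimizes half the variance (under $\pi_{\text{ref}}$) of `r - beta * (log pi - log pi_ref)`. The rewards, reference policy, step size and step count are hypothetical choices, and plain (unnormalized) KL is tracked.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

def kl(p, q):
    return float(np.sum(p * (np.log(p) - np.log(q))))

# Hypothetical toy problem: one prompt, four possible generations.
beta, lr, steps = 1.0, 1.0, 500
r = np.array([1.0, 0.0, 0.5, 2.0])
pi_ref = np.array([0.4, 0.3, 0.2, 0.1])
mu = pi_ref  # off-policy data distribution: the reference, not the learner

# Closed-form optimum of the assumed KL-regularized objective.
w = pi_ref * np.exp(r / beta)
pi_star = w / w.sum()

theta = np.log(pi_ref).copy()  # start the learner at the reference policy
kl_history = []
for _ in range(steps):
    pi = softmax(theta)
    kl_history.append(kl(pi, pi_star))
    q = r - beta * (np.log(pi) - np.log(pi_ref))
    q_bar = np.sum(mu * q)
    # Gradient of 0.5 * Var_mu[q] w.r.t. the logits; the softmax normalizer
    # drops out because the mu-weighted centered terms sum to zero.
    grad = -beta * mu * (q - q_bar)
    theta = theta - lr * grad

final_kl = kl(softmax(theta), pi_star)
```

In this sketch the update never samples from the learner's own policy, yet its fixed point is the policy for which `q` is constant across outcomes, which is exactly `pi_star` under the stated assumptions.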

Theorems & Definitions (9)

  • Theorem 1
  • Theorem 2
  • proof
  • Lemma 3
  • Proposition 4
  • proof
  • proof
  • Lemma 5
  • proof