Nested-ReFT: Efficient Reinforcement Learning for Large Language Model Fine-Tuning via Off-Policy Rollouts

Maxime Heuillet, Yufei Cui, Boxing Chen, Audrey Durand, Prasanna Parthasarathi

TL;DR

Reinforced fine-tuning (ReFT) improves the mathematical reasoning of LLMs but is compute-intensive, because training requires generating multiple chain-of-thought (CoT) completions. The authors introduce Nested-ReFT, a framework in which a nested behavior model, obtained by dynamic layer skipping, performs off-policy rollouts at lower inference cost, coupled with ensemble-based off-policy updates. They provide theoretical guarantees of unbiased gradient estimates with bounded variance and demonstrate compute-efficiency gains across math benchmarks and model sizes. Three bias-mitigation variants (including Retrace-λ) are explored to preserve performance. The approach offers a scalable path to more compute-efficient RL-based fine-tuning of large language models, with potential applicability beyond math reasoning to other domains with verifiable rewards.

Abstract

Advanced reasoning in LLMs on challenging domains like mathematical reasoning can be tackled using reinforced fine-tuning (ReFT) based on verifiable rewards. In standard ReFT frameworks, a behavior model generates multiple completions with answers per problem, and a reward function then scores each answer. While such RL post-training methods demonstrate significant performance improvements across challenging reasoning domains, the computational cost of generating completions during training with multiple inference steps makes the training cost non-trivial. To address this, we draw inspiration from off-policy RL and speculative decoding to introduce a novel ReFT framework, dubbed Nested-ReFT, where a subset of layers of the target model acts as the behavior model to generate off-policy completions during training. The behavior model, configured with dynamic layer skipping per batch during training, decreases the inference cost compared to standard ReFT frameworks. Our theoretical analysis shows that Nested-ReFT yields unbiased gradient estimates with controlled variance. Our empirical analysis demonstrates improved computational efficiency, measured in tokens/sec, across multiple math reasoning benchmarks and model sizes. Additionally, we explore three variants of bias mitigation that reduce the off-policyness of the gradient updates, allowing performance to match that of the baseline ReFT.
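The per-batch layer skipping described in the abstract can be illustrated with a minimal sketch. This is not the authors' implementation: the names `sample_skipped_layers` and `forward`, the `skip_ratio` parameter, and the idea of a `protected` set of layers that may never be skipped are all assumptions made for illustration, and the "layers" are toy functions standing in for transformer blocks.

```python
import random
from typing import Callable, FrozenSet, Sequence, Set

# A toy "layer" maps an activation to an activation (stand-in for a transformer block).
Layer = Callable[[float], float]


def sample_skipped_layers(
    num_layers: int,
    skip_ratio: float,
    protected: Set[int],
    rng: random.Random,
) -> Set[int]:
    """Sample the layer indices to skip for one batch (hypothetical helper).

    `protected` indices are assumed to be ineligible for skipping.
    """
    valid = [i for i in range(num_layers) if i not in protected]
    k = min(len(valid), int(round(skip_ratio * num_layers)))
    return set(rng.sample(valid, k))


def forward(
    layers: Sequence[Layer],
    x: float,
    skipped: FrozenSet[int] = frozenset(),
) -> float:
    """Run the stack; a skipped layer acts as the identity."""
    for i, layer in enumerate(layers):
        if i in skipped:
            continue
        x = layer(x)
    return x


# Toy target model: 8 layers, each adding 1 to its input.
layers = [lambda h: h + 1.0 for _ in range(8)]
rng = random.Random(0)

# Behavior model for this batch: same weights, a random subset of layers skipped.
skipped = sample_skipped_layers(len(layers), skip_ratio=0.25, protected={0, 7}, rng=rng)
behavior_out = forward(layers, 0.0, frozenset(skipped))  # 6 layers executed
target_out = forward(layers, 0.0)                        # all 8 layers executed
```

Because the behavior model shares weights with the target model and only executes fewer layers, rollouts are cheaper but no longer come from the target policy, which is why the updates must be treated as off-policy.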

Paper Structure

This paper contains 36 sections, 1 theorem, 22 equations, 3 figures, 2 tables, and 1 algorithm.

Key Result

Theorem 1

(Convergence of Policy Gradient with Ensemble Behavior Policies) Under the assumptions stated in the paper, the policy gradient update using an ensemble of behavior policies converges to an optimal off-policy update given by the expected advantage function weighted by the mean behavior policy $\bar{\eta}_{\mathcal{Z}}$.
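One way to see why a mean behavior policy appears in the statement: if each rollout is generated by a member of the ensemble chosen uniformly at random, the rollouts are distributed according to the average of the members' distributions. The sketch below checks this numerically on a two-action toy problem. It is an illustration under that uniform-mixing assumption, not the paper's proof; `mean_policy`, `expected_advantage`, and all numbers are hypothetical.

```python
from typing import Dict, List

Policy = Dict[str, float]  # action -> probability


def mean_policy(policies: List[Policy]) -> Policy:
    """Average the action probabilities of the ensemble members (uniform mixture)."""
    n = len(policies)
    return {a: sum(p[a] for p in policies) / n for a in policies[0]}


def expected_advantage(policy: Policy, advantage: Dict[str, float]) -> float:
    """Expectation of the advantage under a given policy."""
    return sum(policy[a] * advantage[a] for a in policy)


# Two hypothetical behavior policies (e.g. two layer-skipping configurations).
ensemble = [
    {"a": 0.5, "b": 0.5},
    {"a": 0.25, "b": 0.75},
]
advantage = {"a": 1.0, "b": -1.0}

eta_bar = mean_policy(ensemble)

# Expectation under the mean policy ...
under_mean = expected_advantage(eta_bar, advantage)
# ... equals the average of the per-member expectations.
avg_of_members = sum(expected_advantage(p, advantage) for p in ensemble) / len(ensemble)
```

Both quantities come out equal by linearity of expectation, so statements about the ensemble as a whole can be phrased in terms of the single averaged policy.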

Figures (3)

  • Figure 1: Fine-tuning on SVAMP. Red annotations indicate the smallest value and green annotations the largest value.
  • Figure 2: Fine-tuning on GSM8k. Red annotations indicate the smallest value and green annotations the largest value.
  • Figure 3: Fine-tuning on Math12k. Red annotations indicate the smallest value and green annotations the largest value.

Theorems & Definitions (5)

  • Definition 1: Set of transformer layer indices
  • Definition 2: Set of valid layer indices
  • Definition 3: Layer-skipping module
  • Theorem 1
  • proof