RAPID: An Efficient Reinforcement Learning Algorithm for Small Language Models
Lianghuan Huang, Sagnik Anupam, Insup Lee, Shuo Li, Osbert Bastani
TL;DR
RAPID targets the computational bottleneck of reinforcement learning fine-tuning for small language models by splitting training into two phases: inference runs in large batches, and backpropagation runs in mini-batches. It introduces group relative policy gradient (GRPG) and an importance-weighted off-policy extension that corrects for the mismatch between the policy that generated the samples and the policy being updated. Across MBPP+, MATH, and MiniF2F, RAPID trains 11%–34% faster with comparable or improved accuracy, and its analysis highlights the trade-off between inference batch size and gradient stability. The approach is practical for resource-constrained LM fine-tuning and could generalize to larger models and broader reasoning tasks.
Abstract
Reinforcement learning (RL) has emerged as a promising strategy for fine-tuning small language models (SLMs) to solve targeted tasks such as math and coding. However, RL algorithms tend to be resource-intensive and slow to train. We propose RAPID, a novel RL algorithm that can substantially reduce the running time of RL. Our key insight is that RL tends to be costly because training must perform both inference and backpropagation. To maximize use of computational resources, our algorithm performs inference in large batches, and then performs off-policy policy gradient updates in mini-batches. For the off-policy updates, we incorporate group advantage estimation into the policy gradient algorithm and derive an importance-weighted estimator to correct for the bias arising from off-policy learning. Our experiments demonstrate that our algorithm can reduce running time by 11%–34% on three benchmarks compared to state-of-the-art RL algorithms while maintaining similar or better accuracy.
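The two-phase schedule and the importance-weighted group estimator described in the abstract can be illustrated with a minimal Python sketch. It is not the authors' implementation: the mean-centered baseline, the unclipped ratio, and every function name here (`group_advantages`, `importance_weights`, `surrogate_loss`, `minibatches`) are assumptions made for illustration, and the exact estimator derived in the paper may differ. The sketch only computes scalar loss values; a real trainer would differentiate the same expression with an autodiff library.

```python
import math
from typing import Iterator, List, Sequence


def group_advantages(rewards: Sequence[float]) -> List[float]:
    """Center each reward on the mean reward of its group.

    A "group" is the set of responses sampled for one prompt.
    Mean-centering is an assumed baseline, not necessarily the paper's.
    """
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


def importance_weights(
    logp_current: Sequence[float], logp_behavior: Sequence[float]
) -> List[float]:
    """Ratio pi_current(y|x) / pi_behavior(y|x), computed from log-probabilities."""
    return [math.exp(c - b) for c, b in zip(logp_current, logp_behavior)]


def surrogate_loss(
    logp_current: Sequence[float],
    logp_behavior: Sequence[float],
    rewards: Sequence[float],
) -> float:
    """Negative importance-weighted objective for one group of responses.

    Differentiating weight * advantage with respect to the current policy
    yields weight * grad(log pi) * advantage, the off-policy gradient term.
    """
    advantages = group_advantages(rewards)
    weights = importance_weights(logp_current, logp_behavior)
    return -sum(w * a for w, a in zip(weights, advantages)) / len(rewards)


def minibatches(items: Sequence, size: int) -> Iterator[Sequence]:
    """Split one large inference batch into mini-batches for the update phase."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


# Phase 1 (assumed setup): one large inference pass produced three groups,
# each holding rewards and behavior-policy log-probs for four responses.
groups = [
    {"rewards": [1.0, 0.0, 0.0, 1.0], "logp_behavior": [-2.0, -3.0, -2.5, -1.5]},
    {"rewards": [0.0, 0.0, 1.0, 0.0], "logp_behavior": [-1.0, -2.0, -4.0, -2.0]},
    {"rewards": [1.0, 1.0, 1.0, 0.0], "logp_behavior": [-1.2, -0.8, -2.2, -3.0]},
]

# Phase 2: several mini-batch updates reuse the same samples. After the first
# update the current policy differs from the behavior policy, which is why the
# importance weights are needed. Here the drift is faked with a fixed offset.
for step, batch in enumerate(minibatches(groups, size=2)):
    drift = 0.1 * step  # stand-in for the effect of earlier parameter updates
    losses = [
        surrogate_loss(
            [lp + drift for lp in g["logp_behavior"]],
            g["logp_behavior"],
            g["rewards"],
        )
        for g in batch
    ]
    print(f"mini-batch {step}: mean surrogate loss = {sum(losses) / len(losses):.4f}")
```

When the current and behavior policies coincide, every weight equals 1 and the expression reduces to the ordinary on-policy group-baseline objective; the weights matter only once the mini-batch updates have moved the policy away from the one that generated the samples.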
