Robust Policy Optimization to Prevent Catastrophic Forgetting
Mahdi Sabbaghi, George Pappas, Adel Javanmard, Hamed Hassani
TL;DR
The paper addresses catastrophic forgetting during downstream fine-tuning of large language models. It proposes Fine-tuning Robust Policy Optimization (FRPO), which builds robustness into the base policy through a max-min objective over a KL-bounded neighborhood of downstream-adapted policies. From this formulation, FRPO derives an entropic-risk objective parameterized by $\lambda$, integrates it into the GRPO framework with no extra computation, and adds a variance-reducing baseline and a jackknife bias correction. Empirically, FRPO substantially reduces safety degradation across multiple base models and both SFT and RL fine-tuning regimes. It also preserves or improves mathematical accuracy after code fine-tuning, achieving up to a 22% gain on MATH500 over GRPO in code-tuning scenarios. The findings suggest that robustness to future adaptation can be achieved upstream, with broad implications for alignment and continual learning across domains.
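To make the entropic-risk and jackknife ingredients concrete, here is a minimal sketch of a sample-based entropic-risk estimate with the textbook leave-one-out jackknife correction. It assumes the risk takes the standard form $-\lambda \log \mathbb{E}[e^{-r/\lambda}]$ over a group of sampled rewards. The function names are hypothetical, and the paper's exact estimator, its baseline, and how it enters the GRPO update may differ.

```python
import numpy as np


def _log_mean_exp(x):
    """Numerically stable log(mean(exp(x)))."""
    m = np.max(x)
    return m + np.log(np.mean(np.exp(x - m)))


def entropic_risk(rewards, lam):
    """Plug-in estimate of -lam * log E[exp(-r / lam)] (assumed form)."""
    r = np.asarray(rewards, dtype=float)
    return -lam * _log_mean_exp(-r / lam)


def entropic_risk_jackknife(rewards, lam):
    """Generic jackknife bias correction of the plug-in estimate.

    The plug-in estimate is biased because it takes the log of a
    sample mean. The jackknife combines the full-sample estimate
    with the average of the leave-one-out estimates.
    """
    r = np.asarray(rewards, dtype=float)
    n = r.size
    if n < 2:
        raise ValueError("jackknife needs at least two samples")
    full = entropic_risk(r, lam)
    leave_one_out = np.array(
        [entropic_risk(np.delete(r, i), lam) for i in range(n)]
    )
    return n * full - (n - 1) * leave_one_out.mean()
```

Under this assumed form, a large $\lambda$ makes the risk approach the mean reward, while a small $\lambda$ pushes it toward the lowest sampled reward, so $\lambda$ controls how pessimistic the objective is.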
Abstract
Large language models are commonly trained through multi-stage post-training: first via RLHF, then via fine-tuning for other downstream objectives. Yet even small downstream updates can compromise earlier learned behaviors (e.g., safety), exposing a brittleness known as catastrophic forgetting. This suggests that standard RLHF objectives do not guarantee robustness to future adaptation. To address this, most prior work designs downstream-time methods to preserve previously learned behaviors. We argue that preventing such forgetting requires pre-finetuning robustness: the base policy should avoid brittle high-reward solutions whose reward drops sharply under standard fine-tuning. We propose Fine-tuning Robust Policy Optimization (FRPO), a robust RLHF framework that optimizes reward not only at the current policy, but across a KL-bounded neighborhood of policies reachable by downstream adaptation. The key idea is to ensure reward stability under policy shifts via a max-min formulation. By modifying GRPO, we develop an algorithm with no extra computation, and empirically show that it substantially reduces safety degradation across multiple base models and downstream fine-tuning regimes (SFT and RL) while preserving downstream task performance. We further study a math-focused RL setting, demonstrating that FRPO preserves accuracy under subsequent fine-tuning.
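As a rough illustration of how a max-min objective over a KL-bounded neighborhood leads to an entropic-risk form, the following sketch uses the standard duality for KL-constrained worst-case expectations. The radius $\rho$, the direction of the KL divergence, and the omission of prompt conditioning are assumptions made here for brevity; the paper's exact objective may differ.

```latex
% Assumed robust objective: worst-case reward over policies near \pi
\max_{\pi} \;
\min_{\pi' \,:\, \mathrm{KL}(\pi' \,\|\, \pi) \le \rho} \;
\mathbb{E}_{y \sim \pi'}\big[ r(y) \big]

% Standard KL duality for the inner minimization
\min_{\pi' \,:\, \mathrm{KL}(\pi' \,\|\, \pi) \le \rho}
\mathbb{E}_{y \sim \pi'}\big[ r(y) \big]
\;=\;
\sup_{\lambda > 0}
\Big\{ -\lambda \log \mathbb{E}_{y \sim \pi}\big[ e^{-r(y)/\lambda} \big]
       - \lambda \rho \Big\}
```

For a fixed $\lambda$, the term $-\lambda \log \mathbb{E}_{y \sim \pi}[e^{-r(y)/\lambda}]$ is an entropic risk of the reward under the current policy. It can be estimated from samples of $\pi$ alone, without ever constructing the adversarial policy $\pi'$.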
