Robust Policy Optimization to Prevent Catastrophic Forgetting

Mahdi Sabbaghi, George Pappas, Adel Javanmard, Hamed Hassani

TL;DR

The paper addresses catastrophic forgetting during downstream fine-tuning of large language models by proposing Fine-tuning Robust Policy Optimization (FRPO), which embeds robustness into the base policy through a max-min objective over a KL-bounded neighborhood of downstream-adapted policies. Solving the inner minimization yields an entropic-risk objective parameterized by $\lambda$; FRPO integrates this objective into the GRPO framework at no extra computational cost and adds a variance-reducing baseline and a jackknife bias correction. Empirically, FRPO substantially reduces safety degradation across SFT and RL fine-tuning regimes and across base models, and it preserves or improves mathematical accuracy after code fine-tuning, with gains of up to 22% on MATH500 over GRPO in code-tuning scenarios. The findings suggest that robustness to future adaptation can be built in upstream, with implications for alignment and continual learning across domains.
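The TL;DR mentions an entropic-risk objective parameterized by $\lambda$ and a jackknife bias correction. The sketch below is a minimal illustration under our own assumptions, not the paper's implementation: it takes the objective to be $-\lambda \log \mathbb{E}[e^{-r/\lambda}]$ estimated from a GRPO-style group of rewards for one prompt, and it applies the textbook jackknife, which may differ from the paper's exact correction. Because $\log$ is concave, the plug-in estimate from a small group is biased upward (Jensen's inequality), which is the bias a jackknife targets. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def entropic_risk(rewards: Sequence[float], lam: float) -> float:
    """Plug-in estimate of -lam * log(mean_i exp(-r_i / lam)) from one group.

    Uses a log-sum-exp shift so that large |r / lam| does not overflow.
    """
    if lam <= 0:
        raise ValueError("lam must be positive")
    scaled = [-r / lam for r in rewards]
    m = max(scaled)
    log_mean = m + math.log(sum(math.exp(s - m) for s in scaled) / len(scaled))
    return -lam * log_mean


def jackknife_entropic_risk(rewards: Sequence[float], lam: float) -> float:
    """Textbook jackknife correction: n * full - (n - 1) * mean(leave-one-out)."""
    rs: List[float] = list(rewards)
    n = len(rs)
    if n < 2:
        raise ValueError("need at least two rewards")
    full = entropic_risk(rs, lam)
    loo = [entropic_risk(rs[:i] + rs[i + 1:], lam) for i in range(n)]
    return n * full - (n - 1) * sum(loo) / n


if __name__ == "__main__":
    group = [1.0, 0.9, 0.2, 0.8]  # hypothetical rewards for one prompt's group
    print("mean reward:", sum(group) / len(group))
    for lam in (0.2, 0.5, 2.0):
        # Smaller lam weights the worst samples more heavily (more pessimistic).
        print(lam, entropic_risk(group, lam), jackknife_entropic_risk(group, lam))
```

As $\lambda$ grows the value approaches the plain group mean reward, and as $\lambda \to 0$ it approaches the minimum reward in the group. This is consistent with the figure captions, in which smaller $\lambda$ preserves safety more and $\lambda = 2.0$ behaves most like GRPO on helpfulness.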

Abstract

Large language models are commonly trained through multi-stage post-training: first aligned via RLHF, then fine-tuned for other downstream objectives. Yet even small downstream updates can compromise earlier learned behaviors (e.g., safety), exposing a brittleness known as catastrophic forgetting. This suggests that standard RLHF objectives do not guarantee robustness to future adaptation. To address this, most prior work designs downstream-time methods that preserve previously learned behaviors. We argue that preventing forgetting requires pre-fine-tuning robustness: the base policy should avoid brittle high-reward solutions whose reward drops sharply under standard fine-tuning. We propose Fine-tuning Robust Policy Optimization (FRPO), a robust RLHF framework that optimizes reward not only at the current policy but across a KL-bounded neighborhood of policies reachable by downstream adaptation. The key idea is to ensure reward stability under policy shifts via a max-min formulation. By modifying GRPO, we develop an algorithm that requires no extra computation, and we empirically show that it substantially reduces safety degradation across multiple base models and downstream fine-tuning regimes (SFT and RL) while preserving downstream task performance. We further study a math-focused RL setting, demonstrating that FRPO preserves accuracy under subsequent fine-tuning.
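The abstract states that FRPO modifies GRPO without extra computation. One way to see how an entropic objective can fit into a GRPO-style update is a generic score-function derivation (our own sketch, not the paper's algorithm). For $J(\theta) = -\lambda \log \mathbb{E}_{y\sim\pi_\theta}[e^{-r(y)/\lambda}]$ with $Z = \mathbb{E}[e^{-r/\lambda}]$, the gradient is $\mathbb{E}[\lambda(1 - e^{-r/\lambda}/Z)\,\nabla_\theta \log \pi_\theta(y)]$ after adding the constant baseline $\lambda$, whose score term has zero expectation. Replacing $Z$ by its group mean gives per-sample weights computed only from the rewards GRPO already has. The paper's own variance-reducing baseline and jackknife correction are not reproduced here, and the function name is hypothetical.

```python
import math
from typing import List, Sequence


def entropic_advantages(rewards: Sequence[float], lam: float) -> List[float]:
    """Hypothetical per-sample weights A_i = lam * (1 - n * softmax(-r / lam)_i).

    Each A_i would multiply grad log pi(y_i) in a policy-gradient step.
    n * softmax(-r / lam)_i equals exp(-r_i / lam) / mean_j exp(-r_j / lam),
    i.e. the group estimate of exp(-r / lam) / Z.
    """
    if lam <= 0:
        raise ValueError("lam must be positive")
    n = len(rewards)
    scaled = [-r / lam for r in rewards]
    m = max(scaled)  # shift for numerical stability; it cancels in the ratio
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [lam * (1.0 - n * e / total) for e in exps]


if __name__ == "__main__":
    group = [1.0, 0.9, 0.2, 0.8]  # hypothetical rewards for one prompt's group
    mean = sum(group) / len(group)
    print("GRPO-style centered:", [r - mean for r in group])
    for lam in (0.2, 2.0, 100.0):
        # Low-reward samples get negative weights; small lam concentrates on the worst.
        print(lam, entropic_advantages(group, lam))
```

The weights sum to zero within a group, and for large $\lambda$ they reduce to approximately $r_i - \bar r$, the mean-centered advantage used by GRPO (before its standard-deviation scaling). Self-normalizing by the group estimate of $Z$ introduces a small finite-group bias.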

Paper Structure (54 sections, 1 theorem, 29 equations, 13 figures, 2 tables, 1 algorithm)

Key Result

Lemma 3.1

Let $f:\mathbb{R} \!\to\mathbb{R}_+ \cup\{+\infty\}$ be a convex function with $f(1)=0$. Define the likelihood ratio $L(y\mid x):=\frac{Q(y\mid x)}{\pi(y\mid x)}$. Then, the inner problem in eq:primal under the average constraint: admits the following dual form: where $f^*(s):=\sup_{t\ge 0}\{st-f(t)\}$ is the Fenchel conjugate. In addition, if the supremum is finite, it is attained at some $(\la
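As a worked special case (our own illustration under simplifying assumptions, not the lemma's exact statement), take a single prompt, let $Q$ range over distributions with $\mathbb{E}_\pi[L] = 1$, and impose $\mathbb{E}_\pi[f(L)] \le \rho$ with $\rho > 0$, so that $Q = \pi$ is strictly feasible. Standard convex duality, with multiplier $\lambda \ge 0$ for the divergence constraint and $\eta$ for normalization, gives the first line below. Choosing the KL generator $f(t) = t\log t - t + 1$, whose conjugate is $f^*(s) = e^{s} - 1$, and maximizing over $\eta$ gives the second.

```latex
\min_{Q:\ \mathbb{E}_{\pi}[f(L)] \le \rho} \mathbb{E}_{Q}[r]
  = \sup_{\lambda \ge 0,\ \eta \in \mathbb{R}}
    \Big\{ \eta - \lambda\rho - \lambda\,\mathbb{E}_{\pi}\Big[ f^*\Big(\tfrac{\eta - r}{\lambda}\Big) \Big] \Big\}

\sup_{\eta \in \mathbb{R}}
    \Big\{ \eta - \lambda\rho - \lambda\,\mathbb{E}_{\pi}\big[ e^{(\eta - r)/\lambda} - 1 \big] \Big\}
  = -\lambda \log \mathbb{E}_{\pi}\big[ e^{-r/\lambda} \big] - \lambda\rho,
  \qquad \eta^\star = -\lambda \log \mathbb{E}_{\pi}\big[ e^{-r/\lambda} \big]
```

For a fixed $\lambda$, the policy-dependent part is the entropic risk $-\lambda \log \mathbb{E}_\pi[e^{-r/\lambda}]$, consistent with the TL;DR's description of an entropic-risk objective parameterized by $\lambda$.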

Figures (13)

  • Figure 1: Illustration of FRPO. Standard RLHF finds high-reward policies that may lie in sharp regions, whereas our method optimizes for reward-flatness within a KL neighborhood, finding policies that maintain high reward after downstream adaptation.
  • Figure 2: (left/middle) Safety reward for Mistral and Qwen, evaluated on a split of the safety prompts, as fine-tuning moves the policy away from the base (increasing KL), for a sweep over $\lambda$; $\lambda = 0.2$ best preserves the safety reward for both models and yields the flattest landscape. (right) Sequence-level KL averaged over the safety prompts, which increases under a constant learning-rate schedule.
  • Figure 3: Safety evaluation on HarmBench ($\uparrow$ is better) and StrongREJECT ($\downarrow$ is better) after SFT on Alpaca. Panels (a,d) and (b,e) compare models trained from the same reference model (Mistral and Qwen), showing consistent improvements of our method over the GRPO baselines. Panels (c,f) give a broader comparison with other safety-focused methods; our approach achieves the highest refusal rates on HarmBench and competitive StrongREJECT scores.
  • Figure 4: Safety metrics during GSM8K SFT for Mistral and Qwen models. Our method maintains higher refusal rates (left, $\uparrow$ is better) and better StrongREJECT scores (right, $\downarrow$ is better) than the GRPO baselines.
  • Figure 5: (left) Fine-tuning the models on UltraFeedback with GRPO significantly increases the average response length, producing more detailed answers to harmful requests. (right) Helpfulness vs. safety score (1 $-$ StrongREJECT score) for Mistral models after GRPO on UltraFeedback. $\lambda = 0.5$ achieves a better safety score but lower helpfulness, while $\lambda = 2.0$ and GRPO achieve the highest helpfulness scores.
  • ...and 8 more figures
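Figure 2 (right) reports a sequence-level KL averaged over safety prompts. Below is a minimal sketch of one common Monte Carlo estimator. It assumes the KL is taken as $\mathrm{KL}(\pi_{\text{ft}} \,\|\, \pi_{\text{base}})$, that responses are sampled from the fine-tuned policy, and that per-token log-probabilities are available under both models; the captions do not state the direction or the estimator, and the function names are hypothetical.

```python
from typing import Sequence, Tuple


def sequence_log_ratio(logp_ft: Sequence[float], logp_base: Sequence[float]) -> float:
    """Sum over tokens of log pi_ft(y_t | y_<t, x) - log pi_base(y_t | y_<t, x)."""
    if len(logp_ft) != len(logp_base):
        raise ValueError("token log-prob lists must be aligned")
    return sum(a - b for a, b in zip(logp_ft, logp_base))


def average_sequence_kl(
    samples: Sequence[Tuple[Sequence[float], Sequence[float]]],
) -> float:
    """Mean log-ratio over responses sampled from pi_ft (one per prompt here).

    This is an unbiased Monte Carlo estimate of the prompt-averaged
    KL(pi_ft || pi_base); single terms can be negative even though the KL is not.
    """
    if not samples:
        raise ValueError("need at least one sample")
    values = [sequence_log_ratio(ft, base) for ft, base in samples]
    return sum(values) / len(values)


if __name__ == "__main__":
    # Hypothetical per-token log-probs of two sampled responses under both models.
    samples = [([-0.1, -0.5], [-0.3, -0.9]), ([-1.0], [-1.2])]
    print(average_sequence_kl(samples))
```

Summing token log-ratios (rather than averaging them per token) is what makes this a sequence-level quantity, so longer responses contribute more, which matters when fine-tuning changes response length as in Figure 5.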

Theorems & Definitions (3)

  • Lemma 3.1: Inner optimization under a general $f$-divergence
  • Remark 3.2
  • Remark 3.3