
WARP: On the Benefits of Weight Averaged Rewarded Policies

Alexandre Ramé, Johan Ferret, Nino Vieillard, Robert Dadashi, Léonard Hussenot, Pierre-Louis Cedoz, Pier Giuseppe Sessa, Sertan Girgin, Arthur Douillard, Olivier Bachem

TL;DR

This work addresses the tension in RLHF between maximizing reward and preserving pretraining knowledge by formalizing the $\mathrm{KL}$-reward Pareto front and introducing Weight Averaged Rewarded Policies (WARP). WARP combines three weight-averaging operations (an EMA anchor for KL regularization, SLERP-based merging of independently fine-tuned policies, and LITI interpolation toward initialization), applied iteratively to progressively refine the frontier. Empirical results on the Gemma 7B RLHF pipeline show that WARP yields higher rewards at fixed KL and outperforms open-source baselines on a range of benchmarks, including mathematics tasks, albeit at higher compute cost due to multiple RL runs per iteration. The approach connects to distributed learning and iterated amplification concepts, offering a scalable post-training alignment technique that preserves knowledge while enhancing alignment quality.

Abstract

Reinforcement learning from human feedback (RLHF) aligns large language models (LLMs) by encouraging their generations to have high rewards, using a reward model trained on human preferences. To prevent the forgetting of pre-trained knowledge, RLHF usually incorporates a KL regularization; this forces the policy to remain close to its supervised fine-tuned initialization, though it hinders the reward optimization. To tackle the trade-off between KL and reward, in this paper we introduce a novel alignment strategy named Weight Averaged Rewarded Policies (WARP). WARP merges policies in the weight space at three distinct stages. First, it uses the exponential moving average of the policy as a dynamic anchor in the KL regularization. Second, it applies spherical interpolation to merge independently fine-tuned policies into a new enhanced one. Third, it linearly interpolates between this merged model and the initialization, to recover features from pre-training. This procedure is then applied iteratively, with each iteration's final model used as an advanced initialization for the next, progressively refining the KL-reward Pareto front, achieving superior rewards at fixed KL. Experiments with GEMMA policies validate that WARP improves their quality and alignment, outperforming other open-source LLMs.
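
To make the three merging operations concrete, below is a minimal NumPy sketch of EMA anchoring, SLERP on task vectors, and LITI, treating each policy as a single flattened parameter vector. The function names (`ema_update`, `slerp_task_vectors`, `liti`) and the default hyperparameters are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def ema_update(anchor, policy, mu=0.01):
    """EMA anchor used in the KL regularization: anchor <- (1 - mu) * anchor + mu * policy."""
    return (1.0 - mu) * anchor + mu * policy

def slerp_task_vectors(theta_init, theta_1, theta_2, lam=0.5, eps=1e-8):
    """Spherical interpolation of the task vectors delta_i = theta_i - theta_init."""
    d1, d2 = theta_1 - theta_init, theta_2 - theta_init
    cos_omega = np.dot(d1, d2) / (np.linalg.norm(d1) * np.linalg.norm(d2) + eps)
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    if omega < eps:  # nearly parallel task vectors: SLERP degenerates to LERP
        return theta_init + (1.0 - lam) * d1 + lam * d2
    sin_omega = np.sin(omega)
    merged_delta = (np.sin((1.0 - lam) * omega) / sin_omega) * d1 + (np.sin(lam * omega) / sin_omega) * d2
    return theta_init + merged_delta

def liti(theta_sft, theta_merged, eta):
    """Linear interpolation towards initialization (LITI); eta=1 keeps the merged model, eta=0 recovers the SFT init."""
    return (1.0 - eta) * theta_sft + eta * theta_merged
```

Sweeping $\eta$ from $1$ to $0$ in `liti` traces the $\mathrm{KL}$-reward front described above, and WARP reuses one such interpolated point as the initialization of the next iteration.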

Paper Structure

This paper contains 34 sections, 5 theorems, 25 equations, 17 figures, 2 tables, and 1 algorithm.

Key Result

Lemma 1

Under the paper's normalized-delta assumption on the task vectors, SLERP preserves the norm of the task vector.
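
The lemma's exact statement is not reproduced on this page; as a hedged sketch, applying the standard SLERP formula [shoemake1985animating] to the task vectors $\delta_i = \theta_i - \theta_\mathrm{init}$, and assuming the normalized-delta condition amounts to $\|\delta_1\| = \|\delta_2\| = r$:

$$\mathrm{slerp}(\theta_\mathrm{init}, \theta_1, \theta_2, \lambda) = \theta_\mathrm{init} + \frac{\sin\big((1-\lambda)\Omega\big)}{\sin\Omega}\,\delta_1 + \frac{\sin(\lambda\Omega)}{\sin\Omega}\,\delta_2, \qquad \cos\Omega = \frac{\delta_1^\top \delta_2}{\|\delta_1\|\,\|\delta_2\|}.$$

Expanding the squared norm of the merged task vector and using the angle-addition identity with $\cos\Omega = \cos\big((1-\lambda)\Omega + \lambda\Omega\big)$ gives

$$\left\|\frac{\sin((1-\lambda)\Omega)}{\sin\Omega}\,\delta_1 + \frac{\sin(\lambda\Omega)}{\sin\Omega}\,\delta_2\right\|^2 = \frac{r^2}{\sin^2\Omega}\Big(\sin^2((1-\lambda)\Omega) + \sin^2(\lambda\Omega) + 2\sin((1-\lambda)\Omega)\sin(\lambda\Omega)\cos\Omega\Big) = r^2,$$

so the SLERP task vector keeps the norm $r$ of $\delta_1$ and $\delta_2$, whereas plain LERP of equal-norm task vectors shrinks it.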

Figures (17)

  • Figure 1: Illustration of the RLHF alignment process with WARP, starting from a supervised fine-tuned (SFT) LLM. WARP uses model merging by weight averaging at three different stages. First, the exponential moving average (EMA) [izmailov2018] of the policy serves as the anchor for the $\mathrm{KL}$ regularization [jaques2017sequence]. Second, the independently fine-tuned policies are merged by spherical linear interpolation (SLERP) [shoemake1985animating] of their task vectors [2022arXiv221204089I]. Third, we interpolate towards the initialization (LITI) [Wortsman2022robust], revealing a Pareto front of solutions as we slide the interpolating coefficient $\eta$ from $1$ to $0$. This yields the "WARP: 1 iteration" curve in the accompanying Pareto-front plot, which improves over the REINFORCE [williams1992simple] fine-tuning trajectories. Critically, by iteratively using a point from this Pareto front as an advanced initialization for the next episode, WARP improves performance further. Details in Figure 4.
  • Figure 2: SLERP vs. LERP.
  • Figure 3: EMA and SLERP experiments. We first compare RL runs with different anchors and regularization strengths $\beta$ in the $\mathrm{KL}$ regularization, showing their results along training and their $\mathrm{KL}$-reward Pareto fronts. We evaluate every $100$ steps and train for $T=9k$ steps, stopping a run early if it ever reaches a $\mathrm{KL}$ of $200$ (e.g., after $T=1k$ training steps when $\beta=0.0$). The last panel plots the reward obtained when merging two policies (trained independently for $T$ RL steps, each with its own EMA anchor) with interpolating coefficient $\lambda$; the highest rewards are obtained with SLERP at $\lambda=0.5$ and $T=9k$ steps.
  • Figure 4: LITI and iterative experiments. The first panel considers the LITI of the SLERP of $M=2$ policies after $T$ steps with $\lambda=0.5$, interpolating towards their SFT init as we slide $\eta$, revealing Pareto fronts above the $M=2$ REINFORCE training trajectories. The second panel plots the LITI of the SLERP of $M$ weights with $\lambda=\frac{1}{M}$ after $T=9k$ steps; light-colored areas show standard deviations across $5$ experiments. The last panel illustrates the iterative WARP procedure: we fine-tune $M=2$ policies, each with its own EMA as the anchor, merge them with SLERP, interpolate towards their init with LITI, and iteratively leverage the weights obtained with $\eta=0.3$ as the new initialization for the next iteration (a minimal code sketch of this loop follows the figure list).
  • Figure 5: Detailed illustration of the WARP strategy. From a (pre-trained and supervised fine-tuned) LLM $\theta_\mathrm{init}$, we launch $M=2$ fine-tunings (black arrows). The innovation of WARP lies in the use of model merging by weight averaging at three different stages. First, the exponential moving average (EMA, blue dashed arrows) of each policy (collected at different training steps) serves as the anchor for the $\mathrm{KL}$ regularization (black double-headed dotted arrows). Second, the fine-tuned networks are weight averaged using spherical linear interpolation of their task vectors (SLERP, yellow dashed arrows). Third, we interpolate towards the initialization (LITI, red dashed arrows). The obtained model $\theta_\mathrm{init}^{\prime}$ serves as an updated initialization for the next iteration, progressively refining the model's capabilities and alignment. Overall, the final merged model $\theta_\mathrm{slerp}^{\prime}$ has high reward but also high $\mathrm{KL}$; then, interpolating towards the SFT init reveals a $\mathrm{KL}$-reward Pareto front of solutions: $\{(1-\eta) \cdot \theta_\mathrm{sft} + \eta \cdot \theta_\mathrm{slerp}^{\prime}\mid 0\leq\eta\leq1\}$.
  • ...and 12 more figures
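
To make the iterative procedure described in Figures 4 and 5 concrete, here is a minimal, hypothetical sketch of the outer WARP loop on flattened parameter vectors. The RL step is stubbed out (`rl_finetune_with_ema_anchor` is a placeholder that merely perturbs the weights, not a real training API), the SLERP helper repeats the earlier sketch so the snippet runs on its own, and the hyperparameters ($M=2$, $\lambda=1/M$, $\eta=0.3$) follow the captions above; none of this is the paper's actual code.

```python
import numpy as np

def slerp(theta_init, theta_a, theta_b, lam):
    """Spherical linear interpolation of the two task vectors theta_* - theta_init."""
    da, db = theta_a - theta_init, theta_b - theta_init
    cos_omega = np.dot(da, db) / (np.linalg.norm(da) * np.linalg.norm(db))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    if omega < 1e-8:  # nearly collinear task vectors: SLERP degenerates to LERP
        return theta_init + (1.0 - lam) * da + lam * db
    return theta_init + (np.sin((1.0 - lam) * omega) * da + np.sin(lam * omega) * db) / np.sin(omega)

def rl_finetune_with_ema_anchor(theta_init, seed):
    """Placeholder for one REINFORCE run regularized towards its own EMA anchor;
    a real run would maximize the reward model's score, here we only perturb the weights."""
    rng = np.random.default_rng(seed)
    return theta_init + 0.01 * rng.standard_normal(theta_init.shape)

def warp(theta_sft, num_iterations=2, M=2, eta=0.3):
    """Iterative WARP: M independent RL runs -> SLERP merge -> LITI towards the iteration's init."""
    assert M == 2, "this toy sketch only merges pairs of policies"
    theta_init = theta_sft
    for it in range(num_iterations):
        policies = [rl_finetune_with_ema_anchor(theta_init, seed=100 * it + m) for m in range(M)]
        theta_slerp = slerp(theta_init, policies[0], policies[1], lam=1.0 / M)
        theta_init = (1.0 - eta) * theta_init + eta * theta_slerp  # LITI with eta = 0.3
    return theta_init

theta_final = warp(np.zeros(8))  # toy 8-parameter "policy", just to exercise the loop
```

A final LITI sweep of $\eta$ between the returned weights and $\theta_\mathrm{sft}$ would trace the $\mathrm{KL}$-reward front shown in Figure 1.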

Theorems & Definitions (11)

  • Lemma 1: SLERP task vector
  • proof
  • Lemma 2: LERP task vector
  • proof
  • Lemma 3: LERP reduces $\mathrm{KL}$
  • proof
  • Remark 1
  • Lemma 4: $\mathrm{KL}$ upper bound for interpolated distributions (a sketch follows this list)
  • proof
  • Lemma 5: LITI Pareto optimality
  • ...and 1 more
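
Lemma 4's full statement is not reproduced on this page; as a hedged sketch of one standard ingredient behind bounds of this type (an assumption about the lemma's content, not its exact statement), the joint convexity of the KL divergence gives, for distributions $p$, $q$, $r$ and $0 \le \eta \le 1$:

$$\mathrm{KL}\big((1-\eta)\,p + \eta\,q \,\big\|\, r\big) \;\le\; (1-\eta)\,\mathrm{KL}(p\,\|\,r) + \eta\,\mathrm{KL}(q\,\|\,r).$$

In particular, with $p = r = \pi_\mathrm{sft}$ and $q = \pi_\theta$, the interpolated distribution satisfies $\mathrm{KL}\big((1-\eta)\pi_\mathrm{sft} + \eta\,\pi_\theta \,\|\, \pi_\mathrm{sft}\big) \le \eta\,\mathrm{KL}(\pi_\theta\,\|\,\pi_\mathrm{sft})$, the kind of linear-in-$\eta$ control that motivates the LITI Pareto-front results.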