
Decoding-Time Language Model Alignment with Multiple Objectives

Ruizhe Shi, Yifang Chen, Yushi Hu, Alisa Liu, Hannaneh Hajishirzi, Noah A. Smith, Simon S. Du

TL;DR

This work proposes MOD, a decoding-time algorithm that outputs the next token from a linear combination of the predictions of all base models for any given weighting over different objectives, and uses MOD to quickly experiment with preference weightings to find the best combination of models.

Abstract

Aligning language models (LMs) to human preferences has emerged as a critical pursuit, enabling these models to better serve diverse user needs. Existing methods primarily focus on optimizing LMs for a single reward function, limiting their adaptability to varied objectives. Here, we propose multi-objective decoding (MOD), a decoding-time algorithm that outputs the next token from a linear combination of the predictions of all base models, for any given weighting over different objectives. We exploit a form shared by a family of $f$-divergence regularized alignment approaches (such as PPO, DPO, and their variants) to identify a closed-form solution via the Legendre transform, and derive an efficient decoding strategy. Theoretically, we show why existing approaches can be sub-optimal even in natural settings and obtain optimality guarantees for our method. Empirical results demonstrate the effectiveness of the algorithm. For example, compared to a parameter-merging baseline, MOD achieves a 12.8% overall reward improvement when equally optimizing towards $3$ objectives. Moreover, we experiment with MOD on combining three fully fine-tuned LLMs of different model sizes, each aimed at a different objective such as safety, coding, or general user preference. Unlike traditional methods that require careful curation of a mixture of datasets to achieve comprehensive improvement, we can quickly experiment with preference weightings using MOD to find the best combination of models. Our best combination reduces toxicity on Toxigen to nearly 0% and achieves a 7.9–33.3% improvement across the other three metrics (i.e., Codex@1, GSM-COT, BBH-COT).
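To make "a linear combination of predictions" concrete, here is a worked special case under assumptions that this summary does not state: each base policy $\pi_i$ is the optimum of a reverse-KL-regularized objective $\max_\pi \mathbb{E}[r_i] - \beta\,\mathrm{KL}(\pi \,\|\, \pi_{\mathrm{ref}})$ with a shared $\beta$ and a shared reference $\pi_{\mathrm{ref}}$, and the preference weights satisfy $\sum_i w_i = 1$. Under these assumptions, the policy that is optimal for the weighted reward $\sum_i w_i r_i$ is a weighted geometric mean of the base policies, so its log-probabilities are a linear combination of theirs. The general $f$-divergence case covered by the paper is not reproduced here.

```latex
% Assumed: reverse-KL regularization, shared \beta and \pi_{\mathrm{ref}}, \sum_i w_i = 1
\begin{aligned}
\pi_i(y \mid x) &= \frac{1}{Z_i(x)}\,\pi_{\mathrm{ref}}(y \mid x)\exp\!\big(r_i(x,y)/\beta\big)
\;\Longrightarrow\;
r_i(x,y) = \beta \log\frac{\pi_i(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z_i(x), \\
\pi_w(y \mid x) &\propto \pi_{\mathrm{ref}}(y \mid x)\exp\!\Big(\tfrac{1}{\beta}\textstyle\sum_i w_i\, r_i(x,y)\Big)
= \pi_{\mathrm{ref}}(y \mid x)\prod_i \Big(\frac{\pi_i(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}\Big)^{w_i}\prod_i Z_i(x)^{w_i} \\
&\propto \prod_i \pi_i(y \mid x)^{w_i}
\quad\Longleftrightarrow\quad
\log \pi_w(y \mid x) = \sum_i w_i \log \pi_i(y \mid x) - \log Z_w(x).
\end{aligned}
```

The factor $\prod_i Z_i(x)^{w_i}$ depends only on $x$ and is absorbed into the normalizer $Z_w(x)$, and $\pi_{\mathrm{ref}}$ cancels because the weights sum to 1. This identity holds for whole responses $y$; decoding token by token from combined next-token distributions is a separate step not covered by this derivation.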

Paper Structure (36 sections, 15 theorems, 82 equations, 6 figures, 23 tables, 1 algorithm)

Key Result

Theorem 1

There exists a certain $C_2$ such that … is the optimal solution to the revised optimization problem (eq. main opt).

Figures (6)

  • Figure 1: Multi-objective decoding. We prepare LMs tuned for each objective in advance. Then, given preference weightings $w$, input prompt $x$ and context $y_{<t}$, $y_t$ is greedily decoded from an algebraic combination of predicted probabilities from each LM, achieving precise control.
  • Figure 2: Reddit Summary. The frontier of MOD generally lies over RS and MORLHF.
  • Figure 3: Helpful Assistant. MOD prominently beats RS for each reward pair. When balancing between harmlessness and humor, MOD lags behind the more expensive MORLHF.
  • Figure 4: Safety Alignment. Figures from left to right illustrate $f$-DPO models w.r.t. Reverse KL-divergence, JSD, $0.3$-divergence and $0.5$-divergence, respectively. MODPO is only applicable to KL-divergence, and we report its mean over $3$ seeds. The frontier of MOD generally lies over RS.
  • Figure 5: Fine-grained RLHF. The left figure illustrates the performance of MOD and RS on $\pi_1,\pi_2$, and the right one illustrates the performance on $\pi_1^\star,\pi_2$, where $\pi_1^\star$ is obtained by reversing the sign of the $Q,K$ matrices of the last two layers of $\pi_1$.
  • ...and 1 more figure
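Figure 1 describes greedy decoding of $y_t$ from an algebraic combination of each LM's predicted probabilities. A minimal Python sketch of one such step follows, assuming reverse-KL-tuned base models that share one vocabulary, in which case the combination reduces to a weighted sum of next-token log-probabilities (a weighted geometric mean of probabilities). The helper names `log_softmax`, `mod_next_token_distribution` and `mod_greedy_step` are hypothetical, and obtaining each model's logits for the prompt $x$ and context $y_{<t}$ is left out.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def mod_next_token_distribution(
    per_model_logits: Sequence[Sequence[float]], weights: Sequence[float]
) -> List[float]:
    """Hypothetical helper: combine next-token predictions of several LMs.

    Assumes reverse-KL-tuned base models sharing one vocabulary, so the
    combined score is sum_i w_i * log pi_i(token), then renormalized.
    """
    if len(per_model_logits) != len(weights):
        raise ValueError("need exactly one weight per model")
    vocab_size = len(per_model_logits[0])
    if any(len(l) != vocab_size for l in per_model_logits):
        raise ValueError("all models must share the same vocabulary")
    combined = [0.0] * vocab_size
    for logits, w in zip(per_model_logits, weights):
        for v, lp in enumerate(log_softmax(logits)):
            combined[v] += w * lp
    # Renormalize the weighted geometric mean into a proper distribution.
    return [math.exp(s) for s in log_softmax(combined)]


def mod_greedy_step(
    per_model_logits: Sequence[Sequence[float]], weights: Sequence[float]
) -> int:
    """Greedy decoding: index of the token with the highest combined probability."""
    probs = mod_next_token_distribution(per_model_logits, weights)
    return max(range(len(probs)), key=probs.__getitem__)


if __name__ == "__main__":
    # Toy 3-token vocabulary; each model's "logits" are simply log-probabilities.
    model_a = [math.log(p) for p in (0.60, 0.35, 0.05)]  # prefers token 0
    model_b = [math.log(p) for p in (0.05, 0.35, 0.60)]  # prefers token 2
    for w in ((1.0, 0.0), (0.5, 0.5), (0.0, 1.0)):
        print(w, mod_greedy_step([model_a, model_b], w))
```

Running the toy example prints tokens 0, 1 and 2 in turn: with equal weights neither model's favorite wins, because token 1 has the largest geometric mean ($0.35$ versus $\sqrt{0.6 \cdot 0.05} \approx 0.17$). Changing `weights` at decoding time is the only thing needed to move along the trade-off, with no retraining of the base models.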

Theorems & Definitions (31)

  • Theorem 1: Informal key theorem
  • Theorem 2
  • Remark 1: Clarification
  • Theorem 3
  • Remark 2: Motivating example
  • Theorem 4: KL-divergence perspective
  • Remark 3: Interpretation of conditions
  • Definition 1: $f$-divergence [fdivergence, fdivergencecsiszar1, fdivergencecsiszar2]
  • Definition 2: Barrier function [convexoptimization]
  • Definition 3: Expected calibration error [guo2017calibration, fDPO]
  • ...and 21 more
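Definition 1 concerns the $f$-divergence family, and Figure 4 compares $f$-DPO models trained with reverse KL, JSD and $\alpha$-divergences. The snippet below is a small sketch that evaluates $D_f(p \,\|\, q) = \sum_y q(y)\, f\big(p(y)/q(y)\big)$ for discrete distributions with $q > 0$. The generator conventions are assumptions and may differ from the paper's: `js_f` yields $\mathrm{KL}(p\,\|\,m) + \mathrm{KL}(q\,\|\,m)$ with $m = (p+q)/2$ (twice the usual JSD), and `alpha_f` is one common $\alpha$-divergence parameterization. All function names are hypothetical.

```python
import math
from typing import Callable, Sequence


def _xlogx(x: float) -> float:
    """x * log(x), extended by continuity to 0 at x = 0."""
    return x * math.log(x) if x > 0 else 0.0


def f_divergence(
    p: Sequence[float], q: Sequence[float], f: Callable[[float], float]
) -> float:
    """D_f(p || q) = sum_y q(y) * f(p(y) / q(y)) for discrete p, q with q > 0."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same support")
    if any(qy <= 0 for qy in q):
        raise ValueError("this sketch requires q(y) > 0 everywhere")
    return sum(qy * f(py / qy) for py, qy in zip(p, q))


def reverse_kl_f(u: float) -> float:
    """f(u) = u log u, so D_f(p || q) = KL(p || q); with p = policy, q = reference."""
    return _xlogx(u)


def js_f(u: float) -> float:
    """f(u) = u log u - (u + 1) log((u + 1) / 2); D_f = KL(p||m) + KL(q||m)."""
    return _xlogx(u) - (u + 1) * math.log((u + 1) / 2)


def alpha_f(alpha: float) -> Callable[[float], float]:
    """Assumed alpha-divergence generator, convex with f(1) = 0 for 0 < alpha < 1."""
    if not 0 < alpha < 1:
        raise ValueError("this sketch only supports 0 < alpha < 1")

    def f(u: float) -> float:
        return (u ** (1 - alpha) - (1 - alpha) * u - alpha) / (alpha * (alpha - 1))

    return f


if __name__ == "__main__":
    p, q = [0.5, 0.5], [0.25, 0.75]
    for name, f in (("reverse KL", reverse_kl_f), ("JS-type", js_f), ("alpha=0.3", alpha_f(0.3))):
        print(name, f_divergence(p, q, f))
```

Every generator here is convex with $f(1) = 0$, so each divergence is non-negative and vanishes when $p = q$. For $\alpha = 0.5$ this convention reduces to $2\sum_y \big(\sqrt{p(y)} - \sqrt{q(y)}\big)^2$, a scaled squared Hellinger distance.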