InfAlign: Inference-aware language model alignment

Ananth Balashankar, Ziteng Sun, Jonathan Berant, Jacob Eisenstein, Michael Collins, Adrian Hutter, Jong Lee, Chirag Nagpal, Flavien Prost, Aradhana Sinha, Ananda Theertha Suresh, Ahmad Beirami

TL;DR

InfAlign tackles the mismatch between RLHF training and inference-time decoding by formulating an inference-aware alignment objective that can be solved with a transformed reward. The authors prove that, for any inference-time procedure, there exists a reward transformation under which the training objective matches the inference-time win rate, and they introduce InfAlign-CTRL, a practical calibrate-and-transform RL method that combines reward calibration with exponential-tilting-based transformations. Empirically, InfAlign-CTRL yields 3–8% gains in inference-time win rate for Best-of-N (BoN) and Worst-of-N (WoN) procedures across helpfulness, harmlessness, and summarization tasks, while preserving strong standard win-rate performance. The approach provides a principled, model-agnostic framework for tailoring alignment to specific inference-time strategies, and its calibration step mitigates reward hacking.

Abstract

Language model alignment is a critical step in training modern generative language models. Alignment aims to improve the win rate of a sample from the aligned model against the base model. Today, we are increasingly using inference-time algorithms (e.g., Best-of-N, controlled decoding, tree search) to decode from language models rather than standard sampling. We show that this train/test mismatch makes the standard RLHF framework suboptimal for such inference-time methods. To this end, we propose a framework for inference-aware alignment (InfAlign), which aims to optimize the inference-time win rate of the aligned policy against the base model. We prove that for any inference-time decoding procedure, the optimal aligned policy is the solution to the standard RLHF problem with a transformation of the reward. This motivates the calibrate-and-transform RL (InfAlign-CTRL) algorithm, which involves a reward calibration step and a KL-regularized reward maximization step with a transformation of the calibrated reward. For best-of-N sampling and best-of-N jailbreaking, we propose specific transformations offering up to 3-8% improvement in inference-time win rates. Finally, we show that our proposed reward calibration method is also a strong baseline for optimizing the standard win rate.
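
The calibrate-and-transform recipe can be sketched per prompt as follows. This is a minimal sketch under stated assumptions, not the authors' implementation: the calibrated reward is estimated empirically as the share of base-policy samples (for the same prompt) that a response beats, with ties counted as one half, and the exponential transformation is assumed to take the form $\operatorname{sign}(t)\, e^{t c}$, with $t > 0$ for Best-of-N and $t < 0$ for Worst-of-N. The helper names `calibrate`, `exp_transform`, and `ctrl_reward` are hypothetical, and the KL-regularized RL step that would consume the transformed reward is not shown.

```python
import numpy as np


def calibrate(reward, base_rewards):
    """Assumed empirical calibrated reward: fraction of base-policy samples for
    the same prompt whose reward is below `reward`, with ties counted as 1/2."""
    base = np.sort(np.asarray(base_rewards, dtype=float))
    n_below = np.searchsorted(base, reward, side="left")       # strictly smaller
    n_not_above = np.searchsorted(base, reward, side="right")  # smaller or equal
    return (n_below + n_not_above) / (2.0 * len(base))


def exp_transform(c, t):
    """Assumed exponential transformation sign(t) * exp(t * c).

    t > 0 stresses the top of the calibrated scale (Best-of-N style);
    t < 0 stresses the bottom (Worst-of-N style). Both are increasing in c.
    """
    if t == 0:
        raise ValueError("t must be nonzero")
    return np.sign(t) * np.exp(t * np.asarray(c, dtype=float))


def ctrl_reward(reward, base_rewards, t):
    """Hypothetical helper: calibrate a raw reward, then transform it."""
    return exp_transform(calibrate(reward, base_rewards), t)


base = [0.1, 0.4, 0.4, 0.9]                  # illustrative base-policy rewards, one prompt
responses = np.array([0.0, 0.2, 0.4, 0.95])  # illustrative rewards of candidate responses
print(calibrate(responses, base))            # calibrated rewards in [0, 1]
print(ctrl_reward(responses, base, t=5.0))   # BoN-style transformed rewards
print(ctrl_reward(responses, base, t=-5.0))  # WoN-style transformed rewards
```

For example, a response with reward 0.4 beats one of the four base samples and ties two, so its calibrated reward is (1 + 2 × 0.5) / 4 = 0.5. Sorting the base rewards once and using `np.searchsorted` lets a single set of base samples calibrate many responses.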

Paper Structure (49 sections, 15 theorems, 65 equations, 12 figures, 1 algorithm)

Key Result

Lemma 1

For any base policy $\pi_{\rm ref}$, reward model $r$, inference-time procedure $\mathcal{T}$, and $\beta > 0$, there exists a reward function $\mathcal{R}_{r, \pi_{\rm ref}, \mathcal{T}}$ such that the maximizer of Eq. (kl-rl-tr) solves the optimization problem in Eq. (kl-constrained-ppwr) (def:kl-
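
Lemma 1 reduces InfAlign to an ordinary KL-regularized RL problem with $\mathcal{R}$ in place of $r$. For reference, assuming that objective has the standard KL-regularized form (Definition 1), its maximizer has the well-known closed form below, so the choice of $\mathcal{R}$ alone determines how $\pi^*$ reweights $\pi_{\rm ref}$. This states the general KL-regularized RL result as a sketch, not a transcription of the paper's equations.

```latex
\max_{\pi}\; \mathbb{E}_{x}\Big[\mathbb{E}_{y\sim\pi(\cdot\mid x)}\big[\mathcal{R}(x,y)\big]
  - \beta\, D_{\mathrm{KL}}\big(\pi(\cdot\mid x)\,\|\,\pi_{\rm ref}(\cdot\mid x)\big)\Big]
\quad\Longrightarrow\quad
\pi^{*}(y\mid x) = \frac{\pi_{\rm ref}(y\mid x)\,\exp\!\big(\mathcal{R}(x,y)/\beta\big)}
  {\sum_{y'} \pi_{\rm ref}(y'\mid x)\,\exp\!\big(\mathcal{R}(x,y')/\beta\big)}
```

InfAlign's task is then to pick $\mathcal{R}$ so that the transformed policy $\mathcal{T}_{\pi^*}$, rather than $\pi^*$ itself, maximizes the inference-time win rate.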

Figures (12)

  • Figure 1: Given an inference-time procedure such as Best-of-$N$, standard RLHF suffers from a train/test mismatch between train-time policy $\pi$ and inference-time policy $\mathcal{T}_{\pi}$. InfAlign bridges the gap by optimizing a policy-transformed reward $\mathcal{R}$, yielding a policy $\pi^*$ that is optimized for inference under $\mathcal{T}_{\pi^*}$.
  • Figure 2: Best-of-$N$ (left) and Worst-of-$N$ (right) win rate vs KL tradeoff curves for $N = 4$ with different transformation functions.
  • Figure 3: Results on reward models trained on the Anthropic helpfulness preference dataset. Scatter plot of reward scores and reward ranks on a random sample of 10 prompts in the Anthropic helpfulness dataset. Note that the model shows miscalibration on most prompts, with the degree of miscalibration varying by prompt.
  • Figure 4: (Top row) Standard win rate comparison of InfAlign-CTRL using the identity transformation with other SOTA methods on the Anthropic helpfulness, harmlessness, and Reddit summarization datasets. (Bottom row) Best/Worst-of-$N$ win rate comparison of InfAlign-CTRL using the exponential reward transformation. We report win rates against the base model on the test split, as measured by the PaLM-2 M reward model trained on the corresponding dataset.
  • Figure 5: Best-of-$N$ win rate comparison on the Anthropic helpfulness dataset with $N = 2, 32$ for different alignment methods.
  • ...and 7 more figures
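
The win-rate-versus-KL tradeoff in Figure 2 can be explored numerically in calibrated-reward space. This is a minimal sketch under assumptions not stated on this page: rewards are continuous (no ties), so the calibrated reward $u$ of a base-policy sample is uniform on $[0, 1]$, and the aligned policy is the KL-regularized optimum for a transformed calibrated reward $\phi(u)$, so $u$ has density proportional to $e^{\phi(u)/\beta}$ under $\pi^*$. The function names are hypothetical, and integrals use a midpoint rule.

```python
import numpy as np


def tilted_calibrated_density(phi, beta, m=200_000):
    """Density p and CDF F of the calibrated reward u under pi*.

    Assumes u ~ Uniform(0, 1) under pi_ref and pi* proportional to pi_ref * exp(phi(u) / beta).
    """
    u = (np.arange(m) + 0.5) / m              # midpoint grid on [0, 1]
    w = np.exp(phi(u) / beta)
    p = w / w.mean()                          # normalize so that the integral of p is 1
    F = np.clip(np.cumsum(p) / m, 0.0, 1.0)   # CDF of u under pi*
    return u, p, F


def kl_to_base(p):
    """KL(pi* || pi_ref) = integral of p(u) log p(u), since the base density is 1."""
    return float(np.mean(p * np.log(p)))


def best_of_n_win_rate(u, F, n):
    """P(best of n from pi* beats best of n from pi_ref) = 1 - int n u^(n-1) F(u)^n du."""
    return float(1.0 - np.mean(n * u ** (n - 1) * F ** n))


def worst_of_n_win_rate(u, F, n):
    """P(worst of n from pi* beats worst of n from pi_ref) = int n (1-u)^(n-1) (1-F(u))^n du."""
    return float(np.mean(n * (1.0 - u) ** (n - 1) * (1.0 - F) ** n))


transforms = {
    "identity": lambda u: u,
    "exp, t=5 (BoN-style)": lambda u: np.exp(5.0 * u),
    "exp, t=-5 (WoN-style)": lambda u: -np.exp(-5.0 * u),
}
for name, phi in transforms.items():
    u, p, F = tilted_calibrated_density(phi, beta=0.5)
    print(f"{name:24s} KL={kl_to_base(p):.3f}  "
          f"BoN@4={best_of_n_win_rate(u, F, 4):.3f}  WoN@4={worst_of_n_win_rate(u, F, 4):.3f}")
```

Sweeping `beta` for each transformation and plotting win rate against KL traces curves analogous in spirit to Figure 2; the sketch is not meant to reproduce the paper's numbers.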

Theorems & Definitions (27)

  • Definition 1: KL-regularized RL
  • Definition 2: Calibrated reward
  • Definition 3: Standard win rate
  • Definition 4: Inference-time win rate
  • Definition 5: InfAlign
  • Lemma 1
  • Theorem 1: Characterization of InfAlign solution
  • Corollary 1
  • Lemma 2: Calibration is a bounded monotone increasing transformation of reward
  • Lemma 3: Calibration is invariant under monotone increasing transformations
  • ...and 17 more
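
Lemmas 2 and 3 can be checked empirically on synthetic data. This is a minimal sketch assuming a Monte Carlo form of the calibrated reward (Definition 2): the share of base-policy samples for the same prompt that a response beats, with ties counted as one half. The rewards are random draws standing in for reward-model scores, not data from the paper, and the helper name `calibrated` is hypothetical.

```python
import numpy as np


def calibrated(rewards, base_rewards):
    """Assumed Monte Carlo calibrated reward: for each response, the fraction of
    base-policy samples it beats, with ties counted as 1/2."""
    r = np.asarray(rewards, dtype=float)[:, None]
    b = np.asarray(base_rewards, dtype=float)[None, :]
    return ((r > b) + 0.5 * (r == b)).mean(axis=1)


rng = np.random.default_rng(0)
base = rng.normal(size=500)             # synthetic base-policy rewards for one prompt
policy = rng.normal(loc=0.5, size=300)  # synthetic aligned-policy rewards, same prompt

c = calibrated(policy, base)

# Lemma 2: calibration is bounded in [0, 1] and monotone increasing in the raw reward.
order = np.argsort(policy)
print("bounded:", bool(np.all((c >= 0) & (c <= 1))))
print("monotone:", bool(np.all(np.diff(c[order]) >= 0)))

# Lemma 3: a monotone increasing transformation of the reward leaves calibration unchanged.
print("invariant under exp:", bool(np.allclose(c, calibrated(np.exp(policy), np.exp(base)))))
print("invariant under 3r + 1:", bool(np.allclose(c, calibrated(3 * policy + 1, 3 * base + 1))))
```

Under these assumptions, the mean of `c` equals the empirical pairwise win rate of the policy samples against the base samples (ties counted as one half), which suggests why calibration connects naturally to the standard win rate of Definition 3.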