Learnable Chernoff Baselines for Inference-Time Alignment

Sunil Madhow, Yuchen Liang, Ness Shroff, Yingbin Liang, Yu-Xiang Wang

TL;DR

The paper addresses inference-time alignment for pretrained generative models under KL regularization by introducing Learnable Chernoff Baselines (LCBs). LCBs provide adaptive, model-agnostic rejection sampling envelopes that tilt transition kernels via learned soft-value baselines, enabling efficient sampling with provable total-variation guarantees relative to ideal tilted-rejection schemes. The framework yields end-to-end TV bounds, supports both continuous diffusion and discrete language diffusion, and demonstrates substantial compute savings in Gaussian mixtures and large-language diffusion tasks (LLaDA) without sacrificing alignment quality. This approach offers a scalable, principled path to safer, reward-guided generation at inference time, with practical impact on downstream tasks and potential dual-use considerations.

Abstract

We study inference-time reward-guided alignment for generative models. Existing methods often rely on either architecture-specific adaptations or computationally costly inference procedures. We introduce Learnable Chernoff Baselines (LCBs) as a method for efficiently and approximately sampling from the exponentially tilted kernels that arise from KL-regularized reward alignment. Using only black-box sampling access to the pretrained model, LCBs implement a form of rejection sampling with adaptively selected acceptance probabilities, which allows fine-grained control over inference-compute scaling. We establish total-variation guarantees to the ideal aligned model, and demonstrate in both continuous and discrete diffusion settings that LCB sampling closely matches ideal rejection sampling while using substantially fewer queries to the pretrained model.
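
To make the sampling problem concrete: KL-regularized alignment at temperature $\alpha$ targets a distribution proportional to $p(x)\exp(r(x)/\alpha)$, where $p$ is the pretrained model and $r$ the reward (the symbols $p$ and $r$ are our notation, not necessarily the paper's). The sketch below is a toy illustration of rejection sampling against such a tilt. It is not the paper's algorithm: the uniform proposal, the reward, and the fixed scalar `baseline` are hypothetical stand-ins. With a baseline at the global upper bound of the reward the sampler is exact; a smaller baseline clips the acceptance ratio, which trades some accuracy for fewer proposals.

```python
import math
import random
from collections import Counter


def tilted_rejection_sample(propose, reward, alpha, baseline, rng, max_proposals=100_000):
    """Draw one sample by rejection against an exponential tilt.

    A proposal x is accepted with probability
    min(1, exp((reward(x) - baseline) / alpha)).
    If baseline >= max reward, the accepted sample follows
    p(x) * exp(reward(x) / alpha) (normalized) exactly; a smaller baseline
    clips the ratio, so the result is only approximate.
    Returns (sample, number_of_proposals_used).
    """
    for n in range(1, max_proposals + 1):
        x = propose(rng)
        accept_prob = min(1.0, math.exp((reward(x) - baseline) / alpha))
        if rng.random() < accept_prob:
            return x, n
    raise RuntimeError("no proposal accepted within max_proposals")


def uniform_proposal(rng):
    # Hypothetical stand-in for a black-box draw from the pretrained model.
    return rng.randrange(4)


def toy_reward(x):
    # Hypothetical bounded reward taking values in [0, 1].
    return x / 3.0


def run(baseline, alpha=0.5, n_samples=20_000, seed=0):
    """Return empirical frequencies of each state and mean proposals per sample."""
    rng = random.Random(seed)
    counts, proposals = Counter(), 0
    for _ in range(n_samples):
        x, n = tilted_rejection_sample(uniform_proposal, toy_reward, alpha, baseline, rng)
        counts[x] += 1
        proposals += n
    freqs = [counts[x] / n_samples for x in range(4)]
    return freqs, proposals / n_samples


if __name__ == "__main__":
    for b in (1.0, 0.5):  # global upper bound vs. a lower, clipped baseline
        freqs, cost = run(baseline=b)
        print(f"baseline={b}: freqs={freqs}, mean proposals={cost:.2f}")
```

Running the script prints the empirical state frequencies and the mean number of proposals per accepted sample for each baseline, so the cost and accuracy of the two choices can be compared directly.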

Paper Structure

This paper contains 42 sections, 33 theorems, 212 equations, 10 figures, 2 tables, 3 algorithms.

Key Result

Theorem 3.1

The total variation error between $p^*(x_0)$ and $\hat{p}_0(x_0)$ is bounded as

Figures (10)

  • Figure 1: Two guided language diffusion trajectories (see the LLaDA experiments section). LCB learns state-dependent baselines (solid) that track soft values (dotted), improving proposal efficiency over the global upper bound used by rejection sampling (RS).
  • Figure 2: Gaussian mixture: LCB matches RS using $8\%$ of the proposals, and uses $14\%$ of the proposals required by Bo$N$ for comparable alignment. Chaining LCB with Bo$N$ efficiently boosts alignment to compensate for error in value estimation.
  • Figure 3: LLaDA: Comparison of the reward distributions induced by rejection sampling (RS, the exact rejection sampling algorithm) and by LCBs (the baseline rejection sampling algorithm applied to the LCB). LCB sampling induces the same reward statistics as RS across a variety of temperatures.
  • Figure 4: Gaussian mixture: In green, the number of proposals required by rejection sampling (RS) at temperature $\alpha = 0.2$; in pink, the number required by LCBs ($\delta = 0.1$). Based on the appendix figure of mixture-of-Gaussians samples, we choose $N$ so that Bo$N$ produces visually similar results to the other two methods at each temperature. The gray bar for Bo$N$ corresponds to this $N$.
  • Figure 5: In terms of the reward mass, there are diminishing returns for taking $\delta < 0.1$, while "effective $N$" continues to increase for smaller $\delta$.
  • ...and 5 more figures
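
For readers unfamiliar with the Bo$N$ (best-of-$N$) baseline mentioned in the captions above, a minimal sketch follows. It assumes the standard definition: draw $N$ proposals and keep the one with the highest reward. The proposal and reward functions here are hypothetical toys, not the paper's experimental setup.

```python
import random


def best_of_n(propose, reward, n, rng):
    """Draw n proposals and return the one with the highest reward."""
    candidates = [propose(rng) for _ in range(n)]
    return max(candidates, key=reward)


def toy_proposal(rng):
    # Hypothetical stand-in for a black-box draw from the pretrained model.
    return rng.randrange(4)


def toy_reward(x):
    # Hypothetical reward that increases with the state index.
    return x / 3.0


if __name__ == "__main__":
    rng = random.Random(0)
    for n in (1, 4, 16):
        picks = [best_of_n(toy_proposal, toy_reward, n, rng) for _ in range(1000)]
        mean_reward = sum(toy_reward(x) for x in picks) / len(picks)
        print(f"N={n}: mean reward {mean_reward:.3f}")
```

Every Bo$N$ sample costs exactly $N$ proposals, which is why the captions report proposal counts when comparing it with RS and LCB.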

Theorems & Definitions (59)

  • Theorem 3.1: TV error of $\hat{p}$
  • Proposition 3.2: Rejection sampling for $\hat{p}$
  • Definition 4.1
  • Lemma 4.2: TV:MGF Lemma
  • Lemma 4.3
  • Lemma 4.4
  • Proposition 4.5: Learning the LCB
  • Theorem 4.6
  • Corollary 4.7
  • Theorem 4.8
  • ...and 49 more