Test-Time Alignment of LLMs via Sampling-Based Optimal Control in pre-logit space

Sekitoshi Kanai, Tsukasa Yoshida, Hiroshi Takahashi, Haru Kuroki, Kazumune Hashimoto

TL;DR

The paper addresses the computational cost of RLHF by introducing AISP, a training-free test-time alignment method that perturbs pre-logits with Gaussian noise and optimizes the mean of the perturbation to maximize rewards. It builds a stochastic control framework, derives a free-energy bound, and uses adaptive importance sampling to approximate the optimal perturbation distribution, connecting to best-of-n sampling (BoN) and softmax via a Gaussian pre-logit assumption. Empirical results show that AISP achieves higher rewards with fewer samples than BoN and outperforms RE-Control, and that Batched AISP offers scalable gains. The approach offers a practical, scalable way to align LLMs at inference time without additional training data or fine-tuning.

Abstract

Test-time alignment of large language models (LLMs) has attracted attention because fine-tuning LLMs is computationally expensive. In this paper, we propose a new test-time alignment method called adaptive importance sampling on pre-logits (AISP), based on sampling-based model predictive control with a stochastic control input. AISP applies a Gaussian perturbation to pre-logits, which are the outputs of the penultimate layer, so as to maximize the expected reward with respect to the mean of the perturbation. We demonstrate that the optimal mean is obtained by importance sampling with sampled rewards. AISP outperforms best-of-n sampling in terms of reward per number of samples used and achieves higher rewards than other reward-based test-time alignment methods.
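The loop described in the abstract (sample Gaussian perturbations of the pre-logit, score the sampled outputs, re-estimate the perturbation mean by importance sampling) can be illustrated on a toy problem. The sketch below is not the authors' implementation. It assumes a single-token response, a hypothetical linear output head `W`, a fixed pre-logit `z`, a zero-mean Gaussian prior on the perturbation, and hypothetical hyperparameters `n`, `iters`, `sigma`, and `lam`. The weight combines the exponentiated reward with the prior-to-proposal density ratio, which is the standard importance-sampling correction when samples are drawn around the current mean.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]


def token_probs(W, z, v):
    """Next-token distribution after adding perturbation v to pre-logit z."""
    h = [zi + vi for zi, vi in zip(z, v)]
    return softmax([sum(w * x for w, x in zip(row, h)) for row in W])


def aisp_toy(W, z, reward, n=256, iters=8, sigma=1.0, lam=0.2, seed=0):
    """Toy adaptive importance sampling over a pre-logit perturbation mean."""
    rng = random.Random(seed)
    d = len(z)
    u = [0.0] * d  # mean of the Gaussian proposal, adapted every iteration
    for _ in range(iters):
        samples, log_w = [], []
        for _ in range(n):
            v = [rng.gauss(ui, sigma) for ui in u]  # v ~ N(u, sigma^2 I)
            probs = token_probs(W, z, v)
            token = rng.choices(range(len(probs)), weights=probs)[0]
            # log of exp(r / lam) * N(v; 0, sigma^2 I) / N(v; u, sigma^2 I)
            ratio = sum((vi - ui) ** 2 - vi ** 2 for vi, ui in zip(v, u))
            samples.append(v)
            log_w.append(reward(token) / lam + ratio / (2 * sigma ** 2))
        m = max(log_w)  # subtract the max for numerical stability
        ws = [math.exp(lw - m) for lw in log_w]
        total = sum(ws)
        # Self-normalized importance-weighted mean becomes the new proposal mean.
        u = [sum(w * v[j] for w, v in zip(ws, samples)) / total for j in range(d)]
    return u


# Hypothetical 4-token vocabulary with a 3-dimensional pre-logit.
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [-1.0, -1.0, -1.0]]
z = [1.0, 0.0, 0.0]
target = 2
u = aisp_toy(W, z, reward=lambda tok: 1.0 if tok == target else 0.0)
p_before = token_probs(W, z, [0.0] * 3)[target]
p_after = token_probs(W, z, u)[target]
print(f"p(target) before: {p_before:.3f}, after: {p_after:.3f}")
```

In this toy setting the adapted mean shifts probability mass toward the rewarded token without touching `W`, which mirrors the training-free nature of the method. A real LLM would replace `token_probs` with a forward pass and would perturb the pre-logit at every decoding step.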

Paper Structure

This paper contains 46 sections, 6 theorems, 43 equations, 8 figures, 5 tables, 1 algorithm.

Key Result

Theorem 3.1

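The free energy $F(r,p,\bm{\mathrm{x}},\lambda)$ satisfies $-\lambda F(r,p,\bm{\mathrm{x}},\lambda) \leq J(\bm{\mathrm{x}},U)$, and the equality holds if the control input $V$ is distributed with density $\frac{1}{\eta}\exp\left(\frac{1}{\lambda} r(\bm{\mathrm{x}},\bm{\mathrm{y}}(V))\right)p(V)$, where $\eta$ is a normalization constant given by $\eta=\int_{\mathbb{R}^{d\times \tau}}\exp\left(\frac{1}{\lambda} r(\bm{\mathrm{x}},\bm{\mathrm{y}}(V))\right)p(V)dV$.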
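A bound of this shape can be checked numerically in a discrete setting. The sketch below rests on assumptions that this summary does not confirm: it takes $F$ to be $\log \eta$ and $J$ to be the negative expected reward plus $\lambda$ times the KL divergence from the sampling distribution to $p$. Under those assumed definitions the inequality is the Gibbs variational principle, and the names `cost`, `free_energy`, and `gibbs` are hypothetical helpers.

```python
import math
import random


def kl(q, p):
    """KL(q || p) for discrete distributions."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)


def cost(q, p, r, lam):
    """Assumed form of J: negative expected reward plus lam * KL(q || p)."""
    return -sum(qi * ri for qi, ri in zip(q, r)) + lam * kl(q, p)


def free_energy(p, r, lam):
    """Assumed form of F: log of eta = sum_k p_k * exp(r_k / lam)."""
    return math.log(sum(pi * math.exp(ri / lam) for pi, ri in zip(p, r)))


def gibbs(p, r, lam):
    """Distribution proportional to p_k * exp(r_k / lam)."""
    ws = [pi * math.exp(ri / lam) for pi, ri in zip(p, r)]
    eta = sum(ws)
    return [w / eta for w in ws]


# Hypothetical three-outcome example.
p = [0.5, 0.3, 0.2]
r = [0.0, 1.0, 2.0]
lam = 0.5

lower = -lam * free_energy(p, r, lam)
q_star = gibbs(p, r, lam)

# The bound should hold for arbitrary sampling distributions q.
rng = random.Random(0)
gaps = []
for _ in range(1000):
    raw = [rng.random() + 1e-9 for _ in p]
    total = sum(raw)
    q = [x / total for x in raw]
    gaps.append(cost(q, p, r, lam) - lower)

tight_gap = abs(cost(q_star, p, r, lam) - lower)
print("smallest gap over random q:", min(gaps))
print("gap at the exponentially tilted distribution:", tight_gap)
```

The gap equals $\lambda$ times the KL divergence from `q` to `q_star`, so it is nonnegative and vanishes only at the exponentially tilted distribution.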

Figures (8)

  • Figure 1: Illustration of AISP. $n$ input trajectories $\{\{\bm{v}^i_t\}_{t=1}^{\tau}\}_{i=1}^{n}$ are sampled from $\mathcal{N}(\bm{u}_t, \sigma^2 \bm{I})$. The input $\bm{v}^i_t$ is added to the pre-logit $\bm{z}_t$, which is obtained by applying the LLM to the past tokens $\bm{y}^i_{<t}$. The $t$-th token $y_t^i$ is sampled and concatenated with the past tokens $\bm{y}^i_{<t}$. When $y^i_t$ is the end-of-sequence token, the rewards of $\{\bm{y}(V^i)\}^{n}_{i=1}$ are evaluated and used in adaptive importance sampling for $\bm{u}_t$.
  • Figure 2: Schematic illustration of computational costs (vertical: parallelism; horizontal: iterations) of Batched AISP and BoN for $D$ prompts. When $\kappa D/b\!=\!D, nb\!=\!N \Leftrightarrow \kappa\!=\!b, n\!=\!N/b$, Batched AISP and BoN have almost the same sequential and parallel computational cost.
  • Figure 3: Sample efficiency to improve rewards: reward curve against $k$ iterations. For each iteration, both methods generate 32 samples.
  • Figure 4: Rewards of Batched AISP and BoN using Llama&UltraRM on SHP for five trials.
  • Figure 5: Prompts for GPT-4 evaluation. {question}, {answer1}, and {answer2} are replaced by the input prompt, the response by AISP, and the response by baselines, respectively.
  • ...and 3 more figures
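
The cost-matching condition in the Figure 2 caption can be worked through with concrete numbers. The sketch below assumes a reading of the caption that this summary does not spell out: $\kappa D/b$ is the number of sequential iterations and $nb$ is the number of samples generated in parallel for Batched AISP, while BoN uses $D$ iterations with $N$ parallel samples. The function names and the numbers `D=100`, `N=128`, `b=4` are hypothetical.

```python
def batched_aisp_cost(D, b, kappa, n):
    """(sequential iterations, parallel samples), as read off the caption."""
    return kappa * D / b, n * b


def bon_cost(D, N):
    """(sequential iterations, parallel samples) for best-of-N over D prompts."""
    return D, N


def matched_aisp_settings(N, b):
    """Solve kappa * D / b = D and n * b = N for (kappa, n)."""
    if N % b != 0:
        raise ValueError("N must be divisible by b for an integer n")
    return b, N // b


D, N, b = 100, 128, 4
kappa, n = matched_aisp_settings(N, b)
aisp = batched_aisp_cost(D, b, kappa, n)
bon = bon_cost(D, N)
print("kappa, n:", kappa, n)
print("Batched AISP cost:", aisp)
print("BoN cost:", bon)
```

With these settings both methods use the same number of sequential iterations and the same number of parallel samples, which is the regime in which the caption calls their costs almost the same.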

Theorems & Definitions (9)

  • Theorem 3.1
  • Theorem 3.2
  • Theorem 3.3
  • Theorem
  • proof
  • Theorem
  • proof
  • Theorem
  • proof