Probabilistic Inference in Language Models via Twisted Sequential Monte Carlo

Stephen Zhao, Rob Brekelmans, Alireza Makhzani, Roger Grosse

TL;DR

This work reframes large-language-model inference as sampling from an unnormalized target defined by a terminal potential, and introduces Twisted Sequential Monte Carlo to steer generation by learning per-step twist functions. The authors develop Contrastive Twist Learning to train these twists and connect the framework to soft reinforcement learning, density-ratio methods, and RL-inspired losses. They further introduce bidirectional log-partition-function bounds that enable robust evaluation of inference techniques and KL-based diagnostics, including exact BDMC-based sampling when possible. The approach is demonstrated on tasks including toxic-story generation, sentiment-varied generation, and infilling, providing both improved sampling efficiency and principled evaluation tools for controlled generation. Overall, the paper offers a principled probabilistic framework for inference and evaluation in language models with practical implications for safety, red-teaming, and controllable generation.

Abstract

Numerous capability and safety techniques of Large Language Models (LLMs), including RLHF, automated red-teaming, prompt engineering, and infilling, can be cast as sampling from an unnormalized target distribution defined by a given reward or potential function over the full sequence. In this work, we leverage the rich toolkit of Sequential Monte Carlo (SMC) for these probabilistic inference problems. In particular, we use learned twist functions to estimate the expected future value of the potential at each timestep, which enables us to focus inference-time computation on promising partial sequences. We propose a novel contrastive method for learning the twist functions, and establish connections with the rich literature of soft reinforcement learning. As a complementary application of our twisted SMC framework, we present methods for evaluating the accuracy of language model inference techniques using novel bidirectional SMC bounds on the log partition function. These bounds can be used to estimate the KL divergence between the inference and target distributions in both directions. We apply our inference evaluation techniques to show that twisted SMC is effective for sampling undesirable outputs from a pretrained model (a useful component of harmlessness training and automated red-teaming), generating reviews with varied sentiment, and performing infilling tasks.
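The resampling idea described in the abstract can be illustrated with a minimal sketch. It assumes a tiny Markov "language model" over three tokens, a hand-written terminal potential, and a hand-written heuristic twist; all of these (`p0_next`, `phi`, `twist`, `twisted_smc`) are hypothetical stand-ins, not the paper's models or learned twists. Particles are proposed from the base model, weighted by the ratio of consecutive twist values, and resampled at every step.

```python
import itertools
import numpy as np

def twisted_smc(p0_next, twist, phi, T, K, rng):
    """K-particle SMC with the base model as proposal and resampling every step.

    The incremental weight is twist_t / twist_{t-1}, with twist_0 = 1 and the
    terminal potential phi used in place of the twist at the final step.
    Returns the resampled particles and an estimate of log Z.
    """
    particles = [()] * K
    prev = np.ones(K)          # twist value at t = 0 (convention: 1)
    log_z = 0.0
    for t in range(1, T + 1):
        extended = []
        for seq in particles:  # propose one new token per particle from p0
            probs = p0_next(seq)
            extended.append(seq + (int(rng.choice(len(probs), p=probs)),))
        cur = np.array([phi(s) if t == T else twist(s) for s in extended])
        w = cur / prev                      # incremental importance weights
        log_z += float(np.log(w.mean()))    # running log-partition estimate
        idx = rng.choice(K, size=K, p=w / w.sum())  # multinomial resampling
        particles = [extended[i] for i in idx]
        prev = cur[idx]
    return particles, log_z

# Hypothetical toy setup (assumption, for illustration only).
V, T, K = 3, 3, 4000
rng = np.random.default_rng(0)
init = rng.dirichlet(np.ones(V))             # p0(s_1)
trans = rng.dirichlet(np.ones(V), size=V)    # p0(s_{t+1} | s_t)

def p0_next(prefix):
    return init if len(prefix) == 0 else trans[prefix[-1]]

def phi(seq):
    return float(np.exp(seq.count(0)))       # potential favours token 0

def twist(prefix):
    return float(np.exp(prefix.count(0)))    # heuristic twist, not learned

def p0_prob(seq):
    prob = 1.0
    for t in range(len(seq)):
        prob *= p0_next(seq[:t])[seq[t]]
    return prob

particles, log_z_hat = twisted_smc(p0_next, twist, phi, T, K, rng)
true_log_z = float(np.log(sum(
    p0_prob(s) * phi(s) for s in itertools.product(range(V), repeat=T))))
print(log_z_hat, true_log_z)
```

On a problem this small the exact log partition function is available by enumeration, so the SMC estimate can be compared against it directly; for real language models that comparison is not possible, which is where bounds on the log partition function become useful.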

Paper Structure

This paper contains 106 sections, 12 theorems, 116 equations, 4 figures, 5 tables, 2 algorithms.

Key Result

Proposition 3.2

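For a given target distribution $\sigma(\mathbf{s}_{1:T}) \propto p_0(\mathbf{s}_{1:T})\,\phi(\mathbf{s}_{1:T})$, the optimal twist functions $\psi^*_{t}(\mathbf{s}_{1:t})$ (in regions where $p_0(\mathbf{s}_{1:t})>0$) correspond to intermediate targets that match the target marginals:

$$\pi^*_t(\mathbf{s}_{1:t}) \propto p_0(\mathbf{s}_{1:t})\,\psi^*_t(\mathbf{s}_{1:t}), \qquad \pi^*_t(\mathbf{s}_{1:t}) = \sigma(\mathbf{s}_{1:t}).$$

Up to a constant independent of $\mathbf{s}_{1:t}$, the optimal twists are

$$\psi^*_t(\mathbf{s}_{1:t}) \propto \sum_{\mathbf{s}_{t+1:T}} p_0(\mathbf{s}_{t+1:T} \mid \mathbf{s}_{1:t})\,\phi(\mathbf{s}_{1:T})$$

and satisfy the recursion

$$\psi^*_t(\mathbf{s}_{1:t}) \propto \sum_{s_{t+1}} p_0(s_{t+1} \mid \mathbf{s}_{1:t})\,\psi^*_{t+1}(\mathbf{s}_{1:t+1}).$$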
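As a sanity check on the idea behind this proposition, the sketch below computes twists by a backward recursion on a toy model small enough to enumerate. It assumes, following the abstract, that the optimal twist is the expected future value of the potential under the base model; the Markov base model, the potential `phi`, and all function names are hypothetical and chosen only for illustration.

```python
import itertools
import numpy as np

V, T = 3, 3  # hypothetical toy vocabulary size and sequence length
rng = np.random.default_rng(0)
init = rng.dirichlet(np.ones(V))            # p0(s_1)
trans = rng.dirichlet(np.ones(V), size=V)   # p0(s_{t+1} | s_t), rows sum to 1

def p0_next(prefix):
    """Next-token distribution of the toy base model given a prefix tuple."""
    return init if len(prefix) == 0 else trans[prefix[-1]]

def p0_prob(seq):
    """Probability of a (partial) sequence under the toy base model."""
    prob = 1.0
    for t in range(len(seq)):
        prob *= p0_next(seq[:t])[seq[t]]
    return prob

def phi(seq):
    """Hypothetical terminal potential: rewards occurrences of token 0."""
    return float(np.exp(seq.count(0)))

def optimal_twists():
    """Backward recursion: psi_t(prefix) = sum_s p0(s | prefix) * psi_{t+1}(prefix + s)."""
    psi = {seq: phi(seq) for seq in itertools.product(range(V), repeat=T)}
    for t in range(T - 1, -1, -1):
        for prefix in itertools.product(range(V), repeat=t):
            probs = p0_next(prefix)
            psi[prefix] = float(sum(probs[s] * psi[prefix + (s,)] for s in range(V)))
    return psi

psi = optimal_twists()

# Partition function by brute-force enumeration over all full sequences.
Z = sum(p0_prob(s) * phi(s) for s in itertools.product(range(V), repeat=T))

# Twisted intermediate target at t = 1, normalised by Z.
marginal_t1 = [p0_prob((s,)) * psi[(s,)] / Z for s in range(V)]
print(psi[()], Z, sum(marginal_t1))
```

With this normalisation the twist of the empty prefix equals the partition function, and the twisted target at each step sums to one, which is what matching the target marginals requires in this toy setting.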

Figures (4)

  • Figure 1: Illustrative example of SIS and (Twisted) SMC for sampling book reviews conditioned on positive sentiment $\phi(\mathbf{s}_{1:T})$. SIS only performs resampling after observing the entire sequence, while SMC can kill or clone partial sequences $\mathbf{s}_{1:t}$ based on incremental importance weights induced by twist functions $\psi_{t}(\mathbf{s}_{1:t})$. Green/red indicate high/low importance weights at each incremental step of SMC, or at the final step of SIS. For SMC with the base model proposal $p_0$ and the optimal twists, the incremental weights $\psi^*_{t}/\psi^*_{t-1}$ (\ref{alg:smc} or \ref{eq:incremental_bg}) are directly correlated with sentiment.
  • Figure 2: Comparison of SIS (IWAE) and SMC bounds on $\log \mathcal{Z}_{\sigma}$ for base proposal $p_0$ and twist-induced proposal $q^{\pi}$, with twists learned with CTL. With the twist-induced proposal, both SIS and SMC bounds are tight; with the base proposal, resampling with learned twists is needed. Resampling based on ESS instead of every-step resampling yields similar results.
  • Figure 2: Toxicity (\ref{sec:toxclass})
  • Figure 3: Graphical Models for extended state-space proposal and target distributions which result in the bidirectional SMC bounds. We show density evaluation in the proposal and target for a fixed set of $\{ s_t^k, \omega_t^k \}_{k=1, t= 1}^{3,2}$. We let the size of the circles reflect the (hypothetical) importance weights of sequences $\mathbf{s} _{1:t}^{\boldsymbol{\bar{h}}_{t}^k}$ and $\omega_t^k$ reflect the (hypothetical) results of resampling with these weights. In $(b)$, we assume fixed $j_{T+1}= j_3 = 1$ as in the text, with $\omega_2^1 = 2$.
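The Figure 2 caption mentions SIS (IWAE) bounds and ESS-based resampling; the quantities involved can be sketched as follows. This is a minimal illustration under stated assumptions: the helper names and the 0.5 resampling threshold are hypothetical choices, not values taken from the paper. When the log weights come from samples drawn from the proposal, the log-mean-exp of the weights is, in expectation, a lower bound on the log partition function.

```python
import numpy as np

def log_mean_exp(log_w):
    """Numerically stable log of the average importance weight."""
    log_w = np.asarray(log_w, dtype=float)
    m = np.max(log_w)
    return float(m + np.log(np.mean(np.exp(log_w - m))))

def effective_sample_size(log_w):
    """ESS = (sum w)^2 / sum w^2, computed from log weights."""
    log_w = np.asarray(log_w, dtype=float)
    w = np.exp(log_w - np.max(log_w))  # rescaling leaves the ESS unchanged
    return float(w.sum() ** 2 / (w ** 2).sum())

def should_resample(log_w, threshold=0.5):
    """Resample only when ESS drops below a fraction of the particle count.

    The 0.5 default is an assumed, commonly used value.
    """
    return effective_sample_size(log_w) < threshold * len(log_w)

uniform = np.zeros(8)                          # equal weights
degenerate = np.array([0.0, -1000.0, -1000.0]) # one particle dominates
print(effective_sample_size(uniform), effective_sample_size(degenerate))
print(should_resample(uniform), should_resample(degenerate))
```

Equal weights give an ESS equal to the number of particles, so no resampling is triggered; a single dominant weight drives the ESS toward one and triggers resampling.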

Theorems & Definitions (24)

  • Definition 3.1: Twisted (Intermediate) Targets
  • Proposition 3.2: Optimal Twists
  • Proposition 3.2
  • Proposition 5.0
  • Definition 1.1: Optimal Twisted SMC Sampling
  • Proposition 1.2: Optimal SMC yields Exact Partition Function Estimation
  • proof
  • Proposition 1.3: Optimality Conditions
  • proof
  • Proposition 1.4: Optimal Intermediate Target Distributions
  • ...and 14 more