Steering diffusion models with quadratic rewards: a fine-grained analysis
Ankur Moitra, Andrej Risteski, Dhruv Rohatgi
TL;DR
This work analyzes the problem of sampling from a diffusion model tilted by a quadratic reward, $p^\star(x) \propto p(x) \exp(r(x))$ with $r(x)=x^\top A x + b^\top x$, and reveals a sharp tractability landscape governed by the rank and definiteness of $A$. The authors prove that linear tilts ($A = 0$) are always efficiently samplable, while negative-definite tilts are intractable even when $A$ has rank 1, via a reduction from PARTITION. They then show that low-rank positive-definite tilts are tractable: a Hubbard–Stratonovich lifting reduces the problem to sampling linear tilts plus normalization estimation, giving a polynomial-time algorithm when the rank is $O(1)$. This PSD-tilt approach yields a practical algorithmic pathway for inference-time tasks such as guided sampling and posterior inference under linear measurements. Overall, the paper clarifies when diffusion-model steering is computationally feasible and provides concrete methods to exploit low-rank structure in rewards.
Abstract
Inference-time algorithms are an emerging paradigm in which pre-trained models are used as subroutines to solve downstream tasks. Such algorithms have been proposed for tasks ranging from inverse problems and guided image generation to reasoning. However, the methods currently deployed in practice are heuristics with a variety of failure modes -- and we have very little understanding of when these heuristics can be efficiently improved. In this paper, we consider the task of sampling from a reward-tilted diffusion model -- that is, sampling from $p^{\star}(x) \propto p(x) \exp(r(x))$ -- given a reward function $r$ and a pre-trained diffusion oracle for $p$. We provide a fine-grained analysis of the computational tractability of this task for quadratic rewards $r(x) = x^\top A x + b^\top x$. We show that linear-reward tilts are always efficiently samplable -- a simple result that seems to have gone unnoticed in the literature. We use this as a building block, along with a conceptually new ingredient -- the Hubbard-Stratonovich transform -- to provide an efficient algorithm for sampling from low-rank positive-definite quadratic tilts, i.e., $r(x) = x^\top A x$ where $A$ is positive-definite and of rank $O(1)$. For negative-definite tilts, i.e., $r(x) = - x^\top A x$ where $A$ is positive-definite, we prove that the problem is intractable even if $A$ is of rank 1 (albeit with exponentially large entries).
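
To illustrate why the Hubbard-Stratonovich transform turns a quadratic tilt into linear ones, the following sketch writes out one standard form of the identity. The factorization $A = VV^\top$ with $V \in \mathbb{R}^{d \times k}$, the auxiliary Gaussian variable $z$, the joint density $\pi$, and the normalizer $Z(b)$ are notation assumed here for illustration; the paper's exact construction and constants may differ.

```latex
% Gaussian moment generating function: for z ~ N(0, I_k) and any u in R^k,
%   E_z[ exp(u^T z) ] = exp(||u||^2 / 2).
% Taking u = sqrt(2) V^T x gives the Hubbard-Stratonovich identity:
\exp\!\left(x^\top V V^\top x\right)
  = \mathbb{E}_{z \sim \mathcal{N}(0, I_k)}
    \left[ \exp\!\left( \sqrt{2}\,(V z)^\top x \right) \right].

% Hence the quadratic tilt is the x-marginal of a joint density over (x, z):
\pi(x, z) \;\propto\; p(x)\, \mathcal{N}(z;\, 0, I_k)\,
  \exp\!\left( \sqrt{2}\,(V z)^\top x \right),
\qquad
\int \pi(x, z)\, dz \;\propto\; p(x) \exp\!\left(x^\top A x\right).

% Conditional on z, the law of x is a linear tilt of p with b = sqrt(2) V z:
\pi(x \mid z) \;\propto\; p(x) \exp\!\left( b^\top x \right),
\qquad b = \sqrt{2}\, V z.

% The z-marginal involves the normalizing constant of that linear tilt:
\pi(z) \;\propto\; \mathcal{N}(z;\, 0, I_k)\; Z\!\left(\sqrt{2}\, V z\right),
\qquad
Z(b) := \mathbb{E}_{x \sim p}\left[ \exp\!\left( b^\top x \right) \right].
```

Under this reading, sampling the quadratic tilt splits into drawing $z \in \mathbb{R}^k$ from a $k$-dimensional density that depends on $Z(\cdot)$, then drawing $x$ from a linear tilt of $p$. The auxiliary variable has dimension equal to the rank of $A$, which is consistent with the rank-$O(1)$ assumption in the abstract.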
