Amortizing intractable inference in diffusion models for vision, language, and control

Siddarth Venkatraman, Moksh Jain, Luca Scimeca, Minsu Kim, Marcin Sendera, Mohsin Hasan, Luke Rowe, Sarthak Mittal, Pablo Lemos, Emmanuel Bengio, Alexandre Adam, Jarrid Rector-Brooks, Yoshua Bengio, Glen Berseth, Nikolay Malkin

TL;DR

This paper tackles the challenge of sampling from posteriors defined by a diffusion prior and an arbitrary constraint, i.e., $p^{\rm post}(\mathbf{x})\propto p(\mathbf{x})r(\mathbf{x})$. It introduces Relative Trajectory Balance (RTB), an asymptotically unbiased objective derived from the generative flow network (GFlowNet) view of diffusion models. RTB enables amortized, off-policy posterior sampling under diffusion priors and extends to discrete and sequence-generation settings. It provides a data-free alternative to classifier-guided or KL-regularized methods, with theoretical guarantees and practical benefits for mode coverage, conditioning, and scalability. Empirically, RTB achieves competitive or superior results across vision, language, and offline-control tasks, including classifier-guided image generation, text infilling with a discrete diffusion language model, and offline reinforcement learning with diffusion-based behavior priors. The approach offers a flexible, general framework for unbiased posterior inference with diffusion priors, with potential applications to inverse problems, diffusion-guided decision making, and scalable amortized inference in complex, high-dimensional domains.
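RTB, as summarized above, trains a posterior sampler $q$ so that for every denoising trajectory $\tau = (\mathbf{x}_0, \dots, \mathbf{x}_1)$ the ratio $Z\,q(\tau) / (r(\mathbf{x}_1)\,p(\tau))$ equals one, and the loss is the squared log of that ratio. The minimal sketch below computes this residual for a scalar, Euler-Maruyama-style discretization with Gaussian transitions. The Gaussian step form, the shared noise scale `sigma`, the drift signature `drift(x, t)`, and all function names are illustrative assumptions, not the paper's implementation. The initial-noise density is assumed to be shared by prior and posterior, so it cancels and is omitted.

```python
import math
from typing import Callable, List

# Hypothetical signature for this sketch: drift(x, t) -> float.
Drift = Callable[[float, float], float]


def gaussian_logpdf(x: float, mean: float, var: float) -> float:
    """Log-density of N(mean, var) evaluated at x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def rtb_residual(traj: List[float], prior_drift: Drift, post_drift: Drift,
                 log_r: Callable[[float], float], log_z: float,
                 sigma: float) -> float:
    """log Z + log q(tau) - log r(x_1) - log p(tau) for one trajectory.

    traj[0] is the shared noise sample (its density cancels, by assumption);
    traj[-1] is the generated data point x_1. Both prior and posterior use
    Gaussian steps N(x + drift(x, t) * dt, sigma^2 * dt) in this sketch.
    """
    n = len(traj) - 1
    dt = 1.0 / n
    var = sigma ** 2 * dt
    log_q = log_p = 0.0
    for k in range(n):
        t, x, x_next = k * dt, traj[k], traj[k + 1]
        log_q += gaussian_logpdf(x_next, x + post_drift(x, t) * dt, var)
        log_p += gaussian_logpdf(x_next, x + prior_drift(x, t) * dt, var)
    return log_z + log_q - log_r(traj[-1]) - log_p


def rtb_loss(residuals: List[float]) -> float:
    """Mean squared RTB residual over a batch of trajectories."""
    return sum(res * res for res in residuals) / len(residuals)
```

Because the residual only needs log-probabilities of a given trajectory, the trajectories can come from any behavior policy (for example, a noised or replayed version of the sampler), which is what makes off-policy training possible in this setup.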

Abstract

Diffusion models have emerged as effective distribution estimators in vision, language, and reinforcement learning, but their use as priors in downstream tasks poses an intractable posterior inference problem. This paper studies amortized sampling of the posterior over data, $\mathbf{x}\sim p^{\rm post}(\mathbf{x})\propto p(\mathbf{x})r(\mathbf{x})$, in a model that consists of a diffusion generative model prior $p(\mathbf{x})$ and a black-box constraint or likelihood function $r(\mathbf{x})$. We state and prove the asymptotic correctness of a data-free learning objective, relative trajectory balance, for training a diffusion model that samples from this posterior, a problem that existing methods solve only approximately or in restricted cases. Relative trajectory balance arises from the generative flow network perspective on diffusion models, which allows the use of deep reinforcement learning techniques to improve mode coverage. Experiments illustrate the broad potential of unbiased inference of arbitrary posteriors under diffusion priors: in vision (classifier guidance), language (infilling under a discrete diffusion LLM), and multimodal data (text-to-image generation). Beyond generative modeling, we apply relative trajectory balance to the problem of continuous control with a score-based behavior prior, achieving state-of-the-art results on benchmarks in offline reinforcement learning.
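The abstract's central claim is that a sampler trained with relative trajectory balance is asymptotically unbiased for $p^{\rm post}$ and can be trained from trajectories it did not generate itself. The sketch below checks this on a hypothetical toy problem that is not from the paper: a two-step chain with $\mathbf{x}_0 \in \{0, 1\}$ and $\mathbf{x}_1 \in \{0, 1, 2\}$, made-up prior tables `P0`, `P1`, and a made-up reward `R`. A tabular posterior policy (softmax logits) and $\log Z$ are fit by plain gradient descent on the squared RTB residual, averaged uniformly over all six trajectories rather than over samples from the posterior policy, which stands in for off-policy training. All names and numbers are illustrative assumptions.

```python
import math

# Hypothetical toy prior: x0 in {0, 1}, then x1 in {0, 1, 2} (x1 is the "data").
P0 = [0.5, 0.5]                 # prior p(x0)
P1 = [[0.6, 0.3, 0.1],          # prior p(x1 | x0 = 0)
      [0.1, 0.3, 0.6]]          # prior p(x1 | x0 = 1)
R = [1.0, 0.5, 4.0]             # black-box constraint r(x1)


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def train_rtb(steps=10000, lr=0.2):
    """Fit q(x0) q(x1 | x0) and log Z by gradient descent on the RTB loss."""
    log_z = 0.0
    a = [0.0, 0.0]                        # logits of q(x0)
    b = [[0.0] * 3 for _ in range(2)]     # logits of q(x1 | x0)
    trajs = [(i, j) for i in range(2) for j in range(3)]
    for _ in range(steps):
        q0 = softmax(a)
        q1 = [softmax(row) for row in b]
        g_z, g_a, g_b = 0.0, [0.0, 0.0], [[0.0] * 3 for _ in range(2)]
        # Uniform weighting over trajectories: independent of q (off-policy).
        for i, j in trajs:
            delta = (log_z + math.log(q0[i]) + math.log(q1[i][j])
                     - math.log(R[j]) - math.log(P0[i]) - math.log(P1[i][j]))
            g = 2.0 * delta / len(trajs)  # d(mean delta^2) / d delta
            g_z += g
            for k in range(2):            # d log softmax(a)_i / d a_k
                g_a[k] += g * (float(k == i) - q0[k])
            for k in range(3):            # d log softmax(b_i)_j / d b_ik
                g_b[i][k] += g * (float(k == j) - q1[i][k])
        log_z -= lr * g_z
        a = [v - lr * gv for v, gv in zip(a, g_a)]
        b = [[v - lr * gv for v, gv in zip(row, grow)]
             for row, grow in zip(b, g_b)]
    q0 = softmax(a)
    q1 = [softmax(row) for row in b]
    marginal = [sum(q0[i] * q1[i][j] for i in range(2)) for j in range(3)]
    return math.exp(log_z), marginal
```

At zero loss, $Z\,q(\tau) = p(\tau)\,r(\mathbf{x}_1)$ for every trajectory. Summing over trajectories gives $Z = \sum_{\mathbf{x}_1} p(\mathbf{x}_1) r(\mathbf{x}_1) = 0.35 + 0.15 + 1.4 = 1.9$, and the marginal of $q$ over $\mathbf{x}_1$ is $p(\mathbf{x}_1) r(\mathbf{x}_1) / 1.9$, so `train_rtb()` should approach a normalizer of 1.9 and a marginal of roughly `[0.184, 0.079, 0.737]`.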

Paper Structure (18 sections, 14 equations, 6 figures, 2 tables)