Practical and Asymptotically Exact Conditional Sampling in Diffusion Models

Luhuan Wu, Brian L. Trippe, Christian A. Naesseth, David M. Blei, John P. Cunningham

TL;DR

The paper introduces the Twisted Diffusion Sampler (TDS), a practical sequential Monte Carlo (SMC) method that draws asymptotically exact samples from $p_\theta(x^0 \mid y)$ using an unconditional diffusion model, avoiding task-specific conditional training. TDS uses twisting to incorporate conditioning through tractable approximations derived from denoising predictions, while retaining a theoretical guarantee of convergence as the particle count grows. The authors demonstrate TDS on 2D toy problems, MNIST class-conditional generation, and 3D protein motif-scaffolding tasks with FrameDiff, showing improved accuracy and flexibility over heuristic approximations and purely conditional-training approaches. The work highlights the method's ability to handle inpainting, additional degrees of freedom, and Riemannian manifolds, offering a versatile framework for asymptotically exact conditional sampling in diverse domains. Limitations include computational cost and sensitivity to the quality of the twisting functions; future work aims to improve efficiency and expand conditioning capabilities.

Abstract

Diffusion models have been successful on a range of conditional generation tasks including molecular design and text-to-image generation. However, these achievements have primarily depended on task-specific conditional training or error-prone heuristic approximations. Ideally, a conditional generation method should provide exact samples for a broad range of conditional distributions without requiring task-specific training. To this end, we introduce the Twisted Diffusion Sampler, or TDS. TDS is a sequential Monte Carlo (SMC) algorithm that targets the conditional distributions of diffusion models through simulating a set of weighted particles. The main idea is to use twisting, an SMC technique that enjoys good computational efficiency, to incorporate heuristic approximations without compromising asymptotic exactness. We first find in simulation and in conditional image generation tasks that TDS provides a computational statistical trade-off, yielding more accurate approximations with many particles but with empirical improvements over heuristics with as few as two particles. We then turn to motif-scaffolding, a core task in protein design, using a TDS extension to Riemannian diffusion models. On benchmark test cases, TDS allows flexible conditioning criteria and often outperforms the state of the art.
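The weighted-particle procedure described in the abstract can be illustrated on a toy problem. The following is a minimal sketch, not the authors' implementation, and every modelling choice in it is an assumption made for illustration: the data are one-dimensional with $x^0 \sim N(0, 1)$, so the denoiser $E[x^0 \mid x^t] = \sqrt{\bar\alpha_t}\, x^t$ and the reverse kernel are known in closed form; the observation is $y = x^0 + $ Gaussian noise with standard deviation `sigma_y`; the twisting function is the likelihood evaluated at the denoised estimate; its gradient is written analytically rather than obtained by automatic differentiation; and multinomial resampling is applied at every step. The function name `tds_toy` and the noise schedule are hypothetical.

```python
import math
import random


def log_normal(x, mean, var):
    """Log density of N(mean, var) evaluated at x."""
    return -0.5 * ((x - mean) ** 2 / var + math.log(2.0 * math.pi * var))


def normalize(log_w):
    """Turn unnormalized log-weights into weights that sum to one."""
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    total = sum(w)
    return [wi / total for wi in w]


def tds_toy(y, sigma_y=0.5, K=2000, T=50, seed=0):
    """Twisted SMC on a 1D toy diffusion (all modelling choices are assumptions)."""
    rng = random.Random(seed)
    # Hypothetical linear noise schedule; abar[t] is the cumulative product of alphas.
    alphas = [1.0 - (0.01 + 0.19 * i / (T - 1)) for i in range(T)]
    abar = [1.0]
    for a in alphas:
        abar.append(abar[-1] * a)

    def log_twist(x, t):
        # Likelihood of y evaluated at the denoised estimate sqrt(abar_t) * x.
        return log_normal(y, math.sqrt(abar[t]) * x, sigma_y ** 2)

    def grad_log_twist(x, t):
        # Closed-form gradient of log_twist with respect to x.
        c = math.sqrt(abar[t])
        return c * (y - c * x) / sigma_y ** 2

    # Initial particles x^T ~ N(0, 1), weighted by the twisting function.
    xs = [rng.gauss(0.0, 1.0) for _ in range(K)]
    log_w = [log_twist(x, T) for x in xs]

    for t in range(T, 0, -1):
        # Multinomial resampling proportional to the current weights.
        xs = rng.choices(xs, weights=normalize(log_w), k=K)
        a = alphas[t - 1]
        var = 1.0 - a
        new_xs, log_w = [], []
        for x in xs:
            mean = math.sqrt(a) * x                        # unconditional reverse mean
            prop_mean = mean + var * grad_log_twist(x, t)  # twisted proposal mean
            x_new = rng.gauss(prop_mean, math.sqrt(var))
            # Weight = reverse kernel * new twist / (old twist * twisted proposal).
            log_w.append(
                log_normal(x_new, mean, var)
                + log_twist(x_new, t - 1)
                - log_twist(x, t)
                - log_normal(x_new, prop_mean, var)
            )
            new_xs.append(x_new)
        xs = new_xs
    return xs, normalize(log_w)


x0, w = tds_toy(y=1.5)
post_mean = sum(wi * xi for wi, xi in zip(w, x0))
exact_mean = 1.5 / (1.0 + 0.5 ** 2)  # Gaussian posterior mean for this toy model
print(post_mean, exact_mean)
```

Because the toy posterior is Gaussian, the weighted particle mean can be compared directly with the closed-form posterior mean. At intermediate steps the twisting function is only an approximation to $p(y \mid x^t)$, and the importance weights correct for that mismatch.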

Paper Structure (58 sections, 3 theorems, 46 equations, 17 figures, 1 table, 1 algorithm)

Key Result

Theorem 1

(Informal) Let $\mathbb{P}_{K}(x^{0})=\left(\sum_{k^\prime=1}^K w_{k^\prime}\right)^{-1} \sum_{k=1}^K w_k\delta_{x_k^{0}}(x^{0})$ denote the discrete measure defined by the particles and weights returned by the TDS algorithm (Algorithm 1) with $K$ particles. Under regularity conditions on the twisted proposals and weighting functions, $\mathbb{P}_{K}(x^{0})$ converges setwise to $p_\theta(x^{0} \mid y)$ as $K$ approaches infinity.
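The measure $\mathbb{P}_{K}$ in Theorem 1 is a self-normalized weighted empirical distribution, so an expectation under it is a weighted average over the particles. The sketch below shows that computation; the helper name `self_normalized_estimate` and the example particles and weights are hypothetical and are not taken from the paper.

```python
def self_normalized_estimate(particles, weights, f):
    """Expectation of f under P_K: sum_k w_k f(x_k) divided by sum_k w_k."""
    total = sum(weights)
    return sum(w * f(x) for x, w in zip(particles, weights)) / total


# Hypothetical particles x_k^0 and unnormalized weights w_k.
particles = [0.0, 1.0, 2.0]
weights = [1.0, 1.0, 2.0]

mean_estimate = self_normalized_estimate(particles, weights, lambda x: x)
prob_at_least_one = self_normalized_estimate(
    particles, weights, lambda x: 1.0 if x >= 1.0 else 0.0
)
print(mean_estimate, prob_at_least_one)  # prints 1.25 0.75
```

Passing an indicator function as `f` gives the mass that $\mathbb{P}_{K}$ assigns to a set, which is the quantity the setwise convergence statement concerns.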

Figures (17)

  • Figure 1: Errors of conditional mean estimations with 2 SEM error bars averaged over 25 replicates. TDS applies to all three tasks and provides increasing accuracy with more particles.
  • Figure 2: Image class-conditional generation task.
  • Figure 3: Protein motif-scaffolding case study results.
  • Figure D: Errors of conditional mean estimations with 2 SEM error bars averaged over 25 replicates on a mixture-of-Gaussians unconditional target. TDS applies to all three tasks and provides increasing accuracy with more particles.
  • Figure E: MNIST class-conditional generation. 16 randomly chosen conditional samples from 64 particles in a single run, given class $y$. From top to bottom, $y=0,1,2,3,4,5,6,8,9$.
  • ...and 12 more figures

Theorems & Definitions (4)

  • Theorem 1
  • Theorem 2
  • Theorem 3: Chopin and Papaspiliopoulos (2020), Proposition 11.4
  • Proof