
Learning To Sample From Diffusion Models Via Inverse Reinforcement Learning

Constant Bourdrez, Alexandre Vérine, Olivier Cappé

TL;DR

This work reframes diffusion-model sampling as a finite-horizon MDP and uses state-marginal inverse reinforcement learning to automatically tune sampling-time interventions without retraining the denoiser or learning rewards. By matching expert state occupancies with $f$-divergence objectives and estimating density ratios via a discriminator, it directly optimizes sampling policies for actions like adaptive CFG, Renoise, and stochasticity injection. The approach yields improved FID, precision, and recall across CIFAR-10, FFHQ, and ImageNet, while controlling the trade-off between sample quality and efficiency through policy design and temperature. Practically, the method incurs minimal sampling-time overhead and offers a flexible, reward-free path to principled diffusion-sampling optimization with interpretable divergences guiding precision-recall behavior.
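The density-ratio step mentioned above can be illustrated with a minimal sketch. It assumes a logistic discriminator trained on balanced expert and policy samples; the paper's discriminator architecture is not described here, and the function names are hypothetical. Two one-dimensional Gaussians stand in for the expert occupancy $\mu_E$ and the policy occupancy $\mu_\theta$. Because of this choice, the true log-ratio $\log(\mu_E/\mu_\theta) = -0.5\,s + 0.125$ is known and can be checked against the fitted logit.

```python
import numpy as np

def fit_logistic_discriminator(expert, policy, lr=1.0, steps=500):
    """Fit D(s) = sigmoid(w0 * s + w1) with label 1 = expert, 0 = policy."""
    s = np.concatenate([expert, policy])
    y = np.concatenate([np.ones_like(expert), np.zeros_like(policy)])
    X = np.stack([s, np.ones_like(s)], axis=1)  # features [s, 1]
    w = np.zeros(2)
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))  # D(s) for every state
        w -= lr * X.T @ (p - y) / len(y)    # gradient step on the mean BCE loss
    return w

def density_ratio(w, s):
    # With balanced classes the optimal discriminator satisfies
    # D*/(1 - D*) = mu_E / mu_theta, so the ratio is exp(logit).
    return np.exp(w[0] * s + w[1])

rng = np.random.default_rng(0)
expert = rng.normal(0.0, 1.0, 20_000)  # stand-in for expert (data) states
policy = rng.normal(0.5, 1.0, 20_000)  # stand-in for states visited under pi_theta
w = fit_logistic_discriminator(expert, policy)
# True log-ratio for these two Gaussians: -0.5 * s + 0.125
```

The optimal logit equals $\log(\mu_E/\mu_\theta)$, so the ratio can be read off as `np.exp(logit)` with no separate normalization step. This identity relies on the two classes having equal size.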

Abstract

Diffusion models generate samples through an iterative denoising process, guided by a neural network. While training the denoiser on real-world data is computationally demanding, the sampling procedure itself is more flexible. This adaptability serves as a key lever in practice, enabling improvements in both the quality of generated samples and the efficiency of the sampling process. In this work, we introduce an inverse reinforcement learning framework for learning sampling strategies without retraining the denoiser. We formulate the diffusion sampling procedure as a discrete-time finite-horizon Markov Decision Process, where actions correspond to optional modifications of the sampling dynamics. To optimize action scheduling, we avoid defining an explicit reward function. Instead, we directly match the target behavior expected from the sampler using policy gradient techniques. We provide experimental evidence that this approach can improve the quality of samples generated by pretrained diffusion models and automatically tune sampling hyperparameters.
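To make the MDP formulation concrete, here is a minimal sketch under toy assumptions. It is not the paper's implementation:

- **Denoiser.** The pretrained denoiser is replaced by the exact posterior mean for one-dimensional Gaussian data. The data scale `DATA_STD` is hypothetical.
- **State and action.** The state is the pair $(x_t, \sigma_t)$. The only action is a binary choice of whether to apply EDM-style stochasticity injection with a fixed, hypothetical `gamma`. The paper instead learns profiles such as $\gamma_{\mathrm{EDM}}(x,\sigma)$.
- **Policy.** The policy `churn_prob` and its parameters `theta` are illustrative only.

Each step records $\log \pi_\theta(a_t \mid s_t)$, the quantity a policy-gradient update would need.

```python
import numpy as np

DATA_STD = 0.5  # assumption: toy 1-D data distribution N(0, DATA_STD**2)

def toy_denoiser(x, sigma):
    # Exact posterior mean E[x0 | x0 + sigma * eps = x] for Gaussian data;
    # stands in for a pretrained denoiser network.
    return x * DATA_STD**2 / (DATA_STD**2 + sigma**2)

def churn_prob(x, sigma, theta):
    # Hypothetical per-sample Bernoulli policy pi_theta(a = 1 | x, sigma),
    # logistic in log(sigma) and |x|.
    z = theta[0] + theta[1] * np.log(sigma) + theta[2] * np.abs(x)
    return 1.0 / (1.0 + np.exp(-z))

def rollout(theta, sigmas, n=256, gamma=0.3, seed=0):
    """One episode per sample: state (x_t, sigma_t), action a_t in {0, 1}."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n) * sigmas[0]
    states, log_probs = [], []
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        states.append(x.copy())
        p = churn_prob(x, sigma, theta)
        a = rng.random(n) < p  # sample a_t ~ pi_theta(. | s_t) per trajectory
        log_probs.append(np.where(a, np.log(p), np.log1p(-p)))
        # Action a_t = 1: raise the noise level to sigma_hat = (1 + gamma) * sigma
        # and add fresh noise so that x is consistent with sigma_hat.
        sigma_hat = np.where(a, (1.0 + gamma) * sigma, sigma)
        x = x + np.sqrt(sigma_hat**2 - sigma**2) * rng.standard_normal(n)
        # Euler step of the probability-flow ODE from sigma_hat to sigma_next.
        d = (x - toy_denoiser(x, sigma_hat)) / sigma_hat
        x = x + (sigma_next - sigma_hat) * d
    states.append(x.copy())
    return np.array(states), np.array(log_probs)

sigmas = np.array([80.0, 20.0, 5.0, 1.0, 0.2, 0.0])  # finite horizon T = 5
states, log_probs = rollout(np.array([-1.0, 0.5, 0.0]), sigmas)
```

In this toy, the policy only decides *whether* to modify the sampling dynamics at each step, and the denoiser itself is never changed.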

Paper Structure

This paper contains 21 sections, 1 theorem, 24 equations, 7 figures, 5 tables, and 2 algorithms.

Key Result

Theorem 4.1

Let $\pi_\theta$ be a policy parameterized by $\theta$, and let $\mathcal{L}_f(\theta) = \mathcal{D}_f( \mu_E \Vert \mu_\theta)$ be the $f$-divergence between the expert occupancy measure $\mu_E$ and the occupancy measure $\mu_\theta$ induced by the policy $\pi_\theta$. Then, the gradient of $\mathcal{L}_f(\theta)$ … where $A_t = \sum_{t'\geq t} h_f\!\left( \frac{\mu_E(s_{t'})}{\mu_\theta(s_{t'})} \right)$ is the l…
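The quantity $A_t$ is a reward-to-go over future density ratios. The sketch below shows how it could be computed; the statement above is truncated, so several pieces are assumptions:

- **Convention.** It assumes $\mathcal{D}_f(\mu_E \Vert \mu_\theta) = \mathbb{E}_{\mu_\theta}[f(\mu_E/\mu_\theta)]$. Under this convention, the pointwise derivative with respect to $\mu_\theta$ is $f(r) - r f'(r)$, and the sketch takes that as $h_f$.
- **Two cases.** The helper names and the choice of forward and reverse KL as examples are illustrative.
- **Use in a gradient.** The sketch assumes, as in standard score-function estimators, that $A_t$ multiplies $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$. In an autodiff framework, $A_t$ would then be held constant.

```python
import numpy as np

# Assumed convention: D_f(mu_E || mu_theta) = E_{mu_theta}[ f(mu_E / mu_theta) ],
# whose pointwise derivative w.r.t. mu_theta is h_f(r) = f(r) - r * f'(r).

def h_forward_kl(r):
    # KL(mu_E || mu_theta): f(r) = r log r  ->  h_f(r) = r log r - r (log r + 1) = -r
    return -r

def h_reverse_kl(r):
    # KL(mu_theta || mu_E): f(r) = -log r  ->  h_f(r) = -log r - r * (-1 / r) = 1 - log r
    return 1.0 - np.log(r)

def reward_to_go(h_values):
    """A_t = sum_{t' >= t} h_values[t'] along axis 0 (time), per trajectory."""
    return np.cumsum(h_values[::-1], axis=0)[::-1]

# Example: density ratios r_t = mu_E(s_t) / mu_theta(s_t), T = 3 steps, 2 trajectories
ratios = np.array([[1.0, 2.0],
                   [0.5, 1.0],
                   [2.0, 4.0]])
A = reward_to_go(h_reverse_kl(ratios))  # shape (3, 2): one A_t per step and trajectory
```

Suppose the ratios come from a discriminator logit $\ell = \log(\mu_E/\mu_\theta)$. The reverse-KL term then reduces to $1 - \ell$, so no exponentiation is needed for that case.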

Figures (7)

  • Figure 1: FFHQ samples generated with a fixed guidance weight $\omega=0.1$ (left) and with adaptive guidance optimized with the reverse KL (rKL) divergence (right).
  • Figure 2: Learned classifier-free guidance profiles $\omega(x,\sigma)$ for different $f$-divergences. Higher guidance improves precision but reduces diversity. We report the mean as well as the $0.25$ and $0.75$ quantiles across samples.
  • Figure 3: Learned stochasticity injection profiles $\gamma_{\mathrm{EDM}}(x,\sigma)$ on CIFAR-10 $32 \times 32$ for different $f$-divergences.
  • Figure 4: Impact of temperature on Precision, Recall, and NFE for learned renoise policies on CIFAR-10 $32 \times 32$. Higher temperature increases diversity and NFE.
  • Figure 5: Comparison of $x$-dependent and noise-level-dependent policies for adaptive CFG on CIFAR-10 using KL and reverse-KL divergence minimization. The $x$-dependent policy achieves better FID, precision, and recall by applying guidance selectively based on the current sample state. The left profile corresponds to a policy that depends only on the noise level $\sigma$; the right profile uses both $x$ and $\sigma$. The policy that ignores $x$ produces suboptimal results.
  • ...and 2 more figures

Theorems & Definitions (5)

  • Definition 3.1: Occupancy Measure
  • Definition 3.2: $f$-divergence
  • Theorem 4.1
  • proof
  • proof