Learning To Sample From Diffusion Models Via Inverse Reinforcement Learning
Constant Bourdrez, Alexandre Vérine, Olivier Cappé
TL;DR
This work reframes diffusion-model sampling as a finite-horizon Markov Decision Process (MDP) and uses state-marginal inverse reinforcement learning to automatically tune sampling-time interventions, without retraining the denoiser or learning a reward. It matches expert state occupancies under $f$-divergence objectives, with density ratios estimated by a discriminator, and in this way directly optimizes sampling policies for actions such as adaptive classifier-free guidance (CFG), Renoise, and stochasticity injection. The approach improves FID, precision, and recall on CIFAR-10, FFHQ, and ImageNet. The policy design and a temperature parameter control the trade-off between sample quality and efficiency. In practice, the method adds minimal sampling-time overhead and offers a flexible, reward-free route to principled optimization of diffusion sampling. The choice of divergence provides an interpretable handle on precision-recall behavior.
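The discriminator-based density-ratio step can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. We assume a discriminator $D(s) = \sigma(\ell(s))$ trained with logistic loss to separate expert states (label 1) from sampler states (label 0). At the optimum, $D/(1-D) = p_E/p_\pi$, so the ratio is simply $r(s) = e^{\ell(s)}$. Write the objective as $D_f(p_E \| p_\pi) = \mathbb{E}_{p_\pi}[f(p_E/p_\pi)]$ and differentiate it with respect to $p_\pi$. This gives each visited state a policy-gradient weight $-(f(r) - r f'(r))$. That weight comes from the divergence itself rather than from a separately learned reward. The names `density_ratio`, `state_weight`, and `GENERATORS` are hypothetical.

```python
import math
from typing import Callable, Dict, Tuple

# Hypothetical f-divergence generators, written so that
# D_f(p_E || p_pi) = E_{s ~ p_pi}[ f(p_E(s) / p_pi(s)) ].
# Each entry holds the pair (f, f').
GENERATORS: Dict[str, Tuple[Callable[[float], float], Callable[[float], float]]] = {
    # f(u) = -log u   ->  KL(p_pi || p_E), mode-seeking
    "reverse_kl": (lambda u: -math.log(u), lambda u: -1.0 / u),
    # f(u) = u log u  ->  KL(p_E || p_pi), mass-covering
    "forward_kl": (lambda u: u * math.log(u), lambda u: math.log(u) + 1.0),
}


def density_ratio(logit: float) -> float:
    """r(s) = p_E(s) / p_pi(s) from the logit of an optimal logistic discriminator.

    With D(s) = sigmoid(logit), D / (1 - D) = exp(logit).
    """
    return math.exp(logit)


def state_weight(logit: float, divergence: str) -> float:
    """Per-state weight -(f(r) - r f'(r)) for minimizing D_f(p_E || p_pi)."""
    f, f_prime = GENERATORS[divergence]
    r = density_ratio(logit)
    return -(f(r) - r * f_prime(r))


if __name__ == "__main__":
    for name in GENERATORS:
        # logit = 0 means the discriminator cannot tell expert from sampler (r = 1)
        print(name, [round(state_weight(l, name), 4) for l in (-1.0, 0.0, 1.0)])
```

With $f(u) = -\log u$ the weight reduces to $\log r - 1$. With $f(u) = u \log u$ it reduces to $r$. Constant offsets leave the gradient unchanged, because $\sum_s \nabla p_\pi(s) = 0$. Reverse KL is mode-seeking and forward KL is mass-covering, which is one common way to read a divergence choice in precision-recall terms.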
Abstract
Diffusion models generate samples through an iterative denoising process guided by a neural network. While training the denoiser on real-world data is computationally demanding, the sampling procedure itself is more flexible. This adaptability serves as a key lever in practice, enabling improvements in both the quality of generated samples and the efficiency of the sampling process. In this work, we introduce an inverse reinforcement learning framework for learning sampling strategies without retraining the denoiser. We formulate the diffusion sampling procedure as a discrete-time finite-horizon Markov Decision Process, where actions correspond to optional modifications of the sampling dynamics. To optimize the scheduling of these actions, we do not define an explicit reward function. Instead, we use policy gradient techniques to directly match the target behavior expected from the sampler. We provide experimental evidence that this approach can improve the quality of samples generated by pretrained diffusion models and automatically tune sampling hyperparameters.
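The MDP view can be illustrated with a toy REINFORCE loop. Every component here is a hypothetical stand-in chosen for brevity:

- The state is $(t, x)$ with scalar $x$.
- The "denoiser" is a fixed contraction $x \leftarrow 0.9x$ rather than a trained network.
- The only optional action is injecting extra Gaussian noise, a crude analogue of stochasticity injection.
- The policy is a two-parameter logistic function of $t/T$.
- `weight_fn` stands for any per-state signal, such as a divergence-derived weight obtained from a discriminator.

The sketch only shows how the score-function estimator $\mathbb{E}[G \sum_t \nabla_\theta \log \pi_\theta(a_t \mid s_t)]$ attaches to such a sampler. It is not the paper's algorithm.

```python
import math
import random
from typing import Callable, List, Tuple

Theta = Tuple[float, float]


def action_prob(theta: Theta, t: int, T: int) -> float:
    """Hypothetical policy: P(a_t = 1 | s_t) = sigmoid(theta0 + theta1 * t / T)."""
    z = theta[0] + theta[1] * t / T
    return 1.0 / (1.0 + math.exp(-z))


def rollout(theta: Theta, T: int, rng: random.Random,
            weight_fn: Callable[[int, float], float]) -> Tuple[float, List[Tuple[float, float]]]:
    """Run one episode from t = T down to t = 0 in a toy 1-D sampler.

    Returns the episode return G (sum of weight_fn over visited states) and,
    for every step, grad_theta log pi(a_t | s_t).
    """
    x = rng.gauss(0.0, 1.0)  # x_T ~ N(0, 1)
    ret = 0.0
    scores: List[Tuple[float, float]] = []
    for t in range(T, 0, -1):
        p = action_prob(theta, t, T)
        a = 1 if rng.random() < p else 0
        # Bernoulli policy with logit z: d/dz log pi(a) = a - p;
        # chain rule through z = theta0 + theta1 * t / T.
        scores.append((a - p, (a - p) * t / T))
        x = 0.9 * x                   # placeholder deterministic denoising step
        if a == 1:
            x += rng.gauss(0.0, 0.1)  # optional modification: extra noise
        ret += weight_fn(t - 1, x)    # signal for the new state s_{t-1}
    return ret, scores


def reinforce_gradient(theta: Theta, T: int, n_episodes: int,
                       weight_fn: Callable[[int, float], float],
                       seed: int = 0) -> Tuple[float, float]:
    """Monte Carlo estimate of grad_theta E[G] via REINFORCE (no baseline)."""
    rng = random.Random(seed)
    g0 = g1 = 0.0
    for _ in range(n_episodes):
        ret, scores = rollout(theta, T, rng, weight_fn)
        for s0, s1 in scores:
            g0 += ret * s0
            g1 += ret * s1
    return g0 / n_episodes, g1 / n_episodes


if __name__ == "__main__":
    # Hypothetical signal: prefer final states near x = 0.5.
    def weight(t: int, x: float) -> float:
        return -(x - 0.5) ** 2 if t == 0 else 0.0

    theta: Theta = (0.0, 0.0)
    for _ in range(20):  # plain gradient ascent on the policy parameters only
        g = reinforce_gradient(theta, T=10, n_episodes=200, weight_fn=weight)
        theta = (theta[0] + 0.5 * g[0], theta[1] + 0.5 * g[1])
    print("learned theta:", theta)
```

The denoiser is never differentiated, and only the policy parameters receive gradients. This matches the setting of learning a sampling strategy without retraining the denoiser. A practical version would typically subtract a baseline from the return to reduce variance.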
