Training Neural Samplers with Reverse Diffusive KL Divergence
Jiajun He, Wenlin Chen, Mingtian Zhang, David Barber, José Miguel Hernández-Lobato
TL;DR
This work introduces the reverse diffusive KL divergence (DiKL) for training neural samplers on unnormalized target distributions. It counters the mode-seeking tendency of the standard reverse KL by convolving both the model and the target with a series of Gaussian kernels and accumulating the reverse KL across these noise levels. Training combines denoising score matching (DSM), which estimates the score of the noised model, with the Mixed Score Identity (MSI), which estimates the score of the noised target; together they give a practical gradient-based training routine for implicit generators. Applied to Boltzmann generators, the framework uses equivariant architectures to respect the target's invariances and achieves mode coverage and sampling efficiency competitive with or better than existing samplers on multi-modal energy landscapes. The resulting samplers generate samples in a single step and cover the target's mass well, offering a scalable alternative to diffusion-based and flow-based samplers. Directions for future work include combining DiKL with multi-step sampling strategies and making posterior sampling more stable.
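The noisy-target score step can be illustrated on a one-dimensional toy problem. The Python sketch below is our own illustration, not the paper's code. It assumes a variance-preserving kernel (alpha_t^2 + sigma_t^2 = 1) and assumes the MSI takes the form of a convex combination of the target score identity and Tweedie's identity with weights alpha_t^2 and sigma_t^2. Any convex combination is exact in expectation, and this weighting avoids dividing by alpha_t or sigma_t^2. Self-normalized importance sampling stands in for whatever posterior sampler the paper uses. The target, `MEANS`, `STD`, and every function name are hypothetical.

```python
import math
import random

# Hypothetical 1D target: equal-weight mixture of two Gaussians at +/-3 with
# standard deviation 0.5, known only up to its normalising constant.
MEANS = (-3.0, 3.0)
STD = 0.5


def _component_log_terms(x):
    """Per-component unnormalised log-densities of the toy target at x."""
    return [-0.5 * ((x - m) / STD) ** 2 for m in MEANS]


def log_target(x):
    """Unnormalised log p_d(x), computed with a log-sum-exp for stability."""
    terms = _component_log_terms(x)
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def grad_log_target(x):
    """Score d/dx log p_d(x); the normalising constant does not affect it."""
    terms = _component_log_terms(x)
    top = max(terms)
    resp = [math.exp(t - top) for t in terms]  # unnormalised responsibilities
    return sum(r * (m - x) / STD**2 for r, m in zip(resp, MEANS)) / sum(resp)


def msi_score(x_t, alpha, n=50_000, seed=0):
    """Estimate d/dx_t log p_d^(t)(x_t) for a VP kernel N(x_t; alpha*x0, 1-alpha^2).

    Assumed mixed score identity (weights alpha^2 on the target score identity
    and sigma^2 on Tweedie's identity):
        score(x_t) = E_{p(x0|x_t)}[alpha * (x0 + grad log p_d(x0))] - x_t.
    The posterior expectation is approximated by self-normalised importance
    sampling whose proposal is the Gaussian likelihood viewed as a density in
    x0, N(x0; x_t/alpha, sigma^2/alpha^2), so each weight is just p_d(x0).
    """
    rng = random.Random(seed)
    sigma = math.sqrt(1.0 - alpha**2)
    xs = [rng.gauss(x_t / alpha, sigma / alpha) for _ in range(n)]
    log_w = [log_target(x) for x in xs]
    top = max(log_w)
    w = [math.exp(lw - top) for lw in log_w]  # weights proportional to p_d(x0)
    mean_term = sum(wi * alpha * (x + grad_log_target(x)) for wi, x in zip(w, xs)) / sum(w)
    return mean_term - x_t


def exact_noisy_score(x_t, alpha):
    """Closed-form score of the noised mixture, used only to check msi_score."""
    var = alpha**2 * STD**2 + (1.0 - alpha**2)
    terms = [-0.5 * (x_t - alpha * m) ** 2 / var for m in MEANS]
    top = max(terms)
    resp = [math.exp(t - top) for t in terms]
    return sum(r * (alpha * m - x_t) / var for r, m in zip(resp, MEANS)) / sum(resp)


if __name__ == "__main__":
    for x_t in (-1.0, 1.0):
        print(x_t, msi_score(x_t, alpha=0.8), exact_noisy_score(x_t, alpha=0.8))
```

Because the toy target is a Gaussian mixture, its noised score has a closed form, so `exact_noisy_score` serves only as a check. For a real Boltzmann target only an estimator like `msi_score` would be available, and in higher dimensions plain importance sampling would likely need to be replaced by a stronger posterior sampler.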
Abstract
Training generative models to sample from unnormalized density functions is an important and challenging task in machine learning. Traditional training methods often rely on the reverse Kullback-Leibler (KL) divergence due to its tractability. However, the mode-seeking behavior of reverse KL hinders effective approximation of multi-modal target distributions. To address this, we propose to minimize the reverse KL along diffusion trajectories of both model and target densities. We refer to this objective as the reverse diffusive KL divergence, which allows the model to capture multiple modes. Leveraging this objective, we train neural samplers that can efficiently generate samples from the target distribution in one step. We demonstrate that our method enhances sampling performance across various Boltzmann distributions, including both synthetic multi-modal densities and n-body particle systems.
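For concreteness, the objective described above can be written out as follows. The notation is ours and only a sketch: we assume Gaussian kernels k_t(x_t | x) = N(x_t; alpha_t x, sigma_t^2 I), a finite set of noise levels t = 1, ..., T with nonnegative weights w(t), and a generator x = g_theta(z) driven by base noise z; the paper's exact schedule and weighting may differ.

```latex
\begin{aligned}
p^{(t)}(x_t) &= \int \mathcal{N}\big(x_t;\, \alpha_t x,\, \sigma_t^2 I\big)\, p(x)\, \mathrm{d}x, \\
\mathrm{DiKL}_w\big(p_\theta \,\|\, p_d\big) &= \sum_{t=1}^{T} w(t)\, \mathrm{KL}\big(p_\theta^{(t)} \,\|\, p_d^{(t)}\big), \\
\nabla_\theta \mathrm{KL}\big(p_\theta^{(t)} \,\|\, p_d^{(t)}\big)
  &= \mathbb{E}_{z,\epsilon}\Big[\big(\nabla_{x_t}\log p_\theta^{(t)}(x_t) - \nabla_{x_t}\log p_d^{(t)}(x_t)\big)^{\top}
     \frac{\partial x_t}{\partial \theta}\Big],
  \qquad x_t = \alpha_t\, g_\theta(z) + \sigma_t\, \epsilon .
\end{aligned}
```

When the KL is differentiated, the term involving the parameter score of p_theta^(t) at a fixed point vanishes because a normalized density's score has zero mean, so only the two x_t-scores remain. The first is the quantity DSM estimates, by regressing a network onto (alpha_t x - x_t) / sigma_t^2; the second is the noisy-target score that MSI estimates.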
