Learning the Stein Discrepancy for Training and Evaluating Energy-Based Models without Sampling

Will Grathwohl, Kuan-Chieh Wang, Jorn-Henrik Jacobsen, David Duvenaud, Richard Zemel

TL;DR

This paper introduces LSD, the Learned Stein Discrepancy, which uses a neural-network critic to compare a data density $p$ directly with an unnormalized model $q$ without sampling from the model. By combining Stein's identity with an objective that can be estimated efficiently via Hutchinson's trick, LSD supports both goodness-of-fit testing and training of energy-based models (EBMs) in high dimensions. The authors demonstrate that LSD matches or outperforms kernel-based Stein methods in goodness-of-fit testing and model evaluation, and that it enables sampler-free EBM training that scales to complex densities, including RBMs, ICA, and deep flows. The approach provides a unified, scalable framework for evaluating and learning unnormalized densities, with practical impact on robustness, calibration, and high-dimensional density modeling.
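
For orientation, the quantities named in the summary can be written out as follows. This is a sketch in standard notation rather than a quotation from the paper: the critic $f$, the function class $\mathcal{F}$, and the noise vector $\varepsilon$ are symbols assumed here for illustration. Only $\nabla_x \log q(x)$ appears, so the normalizing constant of $q$ is never needed.

```latex
% Stein's identity: for suitably regular f, the expectation vanishes under q itself
\mathbb{E}_{q(x)}\left[ \nabla_x \log q(x)^\top f(x) + \operatorname{Tr}\left( \nabla_x f(x) \right) \right] = 0

% Stein discrepancy: the largest violation of the identity when x is drawn from the data density p
\mathbf{S}(p, q) = \sup_{f \in \mathcal{F}} \; \mathbb{E}_{p(x)}\left[ \nabla_x \log q(x)^\top f(x) + \operatorname{Tr}\left( \nabla_x f(x) \right) \right]

% Hutchinson's trick: an unbiased estimate of the Jacobian trace from one random probe
\operatorname{Tr}\left( \nabla_x f(x) \right) = \mathbb{E}_{\varepsilon \sim \mathcal{N}(0, I)}\left[ \varepsilon^\top \nabla_x f(x) \, \varepsilon \right]
```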

Abstract

We present a new method for evaluating and training unnormalized density models. Our approach only requires access to the gradient of the unnormalized model's log-density. We estimate the Stein discrepancy between the data density $p(x)$ and the model density $q(x)$ defined by a vector function of the data. We parameterize this function with a neural network and fit its parameters to maximize the discrepancy. This yields a novel goodness-of-fit test which outperforms existing methods on high-dimensional data. Furthermore, optimizing $q(x)$ to minimize this discrepancy produces a novel method for training unnormalized models which scales more gracefully than existing methods. The ability to both learn and compare models is a unique feature of the proposed method.
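
As a rough illustration of the estimator described in the abstract, the sketch below computes a Monte Carlo estimate of the Stein objective for a fixed critic, using Hutchinson's trick for the trace term. It is a minimal sketch under stated assumptions, not the authors' implementation: the model is taken to be $q = \mathcal{N}(0, I)$, the critic is a hand-picked one-layer `tanh` map rather than a trained neural network, and the names `score_q`, `critic`, `critic_jvp`, and `stein_estimate` are hypothetical.

```python
import numpy as np


def score_q(x):
    """Score of the assumed model q = N(0, I): grad_x log q(x) = -x."""
    return -x


def critic(x, W, b):
    """Hypothetical fixed critic f(x) = tanh(W x + b), applied row-wise."""
    return np.tanh(x @ W.T + b)


def critic_jvp(x, eps, W, b):
    """Jacobian-vector product J_f(x) @ eps for the tanh critic, row-wise."""
    z = x @ W.T + b
    return (1.0 - np.tanh(z) ** 2) * (eps @ W.T)


def stein_estimate(x, W, b, rng):
    """Estimate E_x[ score_q(x)^T f(x) + Tr(J_f(x)) ] from samples x.

    The trace is replaced by the Hutchinson estimate eps^T J_f(x) eps
    with one Gaussian probe eps per sample.
    """
    eps = rng.standard_normal(x.shape)
    score_term = np.sum(score_q(x) * critic(x, W, b), axis=1)
    trace_term = np.sum(eps * critic_jvp(x, eps, W, b), axis=1)
    return float(np.mean(score_term + trace_term))


rng = np.random.default_rng(0)
d, n = 5, 50_000

# Hand-picked critic parameters (an assumption; the paper trains the critic).
W, b = 0.1 * np.eye(d), -np.ones(d)

# Case 1: data drawn from the model itself, so Stein's identity should hold.
x_model = rng.standard_normal((n, d))
matched = stein_estimate(x_model, W, b, rng)

# Case 2: data drawn from a shifted Gaussian N(1, I), so the model is wrong.
x_shift = rng.standard_normal((n, d)) + 1.0
mismatched = stein_estimate(x_shift, W, b, rng)

print(f"data ~ q:      {matched:+.3f}")
print(f"data ~ N(1,I): {mismatched:+.3f}")
```

When the data come from the model, the estimate should be close to zero; when the data are shifted, this particular critic yields a clearly positive value. In the method the abstract describes, the critic's parameters would instead be fit by gradient ascent to maximize this quantity.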

Paper Structure

This paper contains 40 sections, 24 equations, 15 figures, 3 algorithms.

Figures (15)

  • Figure 1: Density models trained with approximate MCMC samplers can fail to match the data density while still generating high-quality samples. Samples from approximate MCMC samplers follow a different distribution than the density they are applied to. It is this induced distribution which is trained to match the data. In contrast, our approach $\mathbf{LSD}$ directly matches the model density to the data density without reliance on a sampler.
  • Figure 2: Cutting out the "middle-man" of approximate sampling can lead to simpler training and evaluation that is tied directly to the quality of our model and is not obfuscated by the parameters of an MCMC sampler.
  • Figure 3: Training a neural net to estimate $\mathbf{S}(p, q)$ on 100-dimensional data. In both cases, a near-optimal critic is learned and the true discrepancy is closely approximated.
  • Figure 4: Density models trained using LSD. Top: Data. Bottom: Learned densities.
  • Figure 5: Hypothesis testing results at significance level $0.05$ for perturbed RBMs of increasing data dimension. Perturbation magnitude is on the $x$-axis, rejection rate on the $y$-axis. Number of datapoints $n=1000$. Ideal behavior is a 5% rejection rate when the perturbation is 0 and close to 100% rejection otherwise. In high dimensions our linear-time $\mathbf{LSD}$ matches the performance of the quadratic-time KSD.
  • ...and 10 more figures