NETS: A Non-Equilibrium Transport Sampler

Michael S. Albergo, Eric Vanden-Eijnden

TL;DR

This work introduces the Non-Equilibrium Transport Sampler (NETS) for sampling unnormalized distributions in high dimensions by augmenting annealed Langevin dynamics with a learned drift. Grounded in a Jarzynski-based identity, NETS offers unbiased sampling while allowing the diffusion coefficient to be tuned after training to maximize the effective sample size. The authors present two drift-learning objectives, PINN and Action Matching, that avoid backpropagation through SDE solvers, and they show that some of these objectives control the KL divergence of the estimated distribution from the target. Demonstrations on standard benchmarks, high-dimensional Gaussian mixtures, and a lattice $\varphi^4$ theory show that NETS surpasses existing baselines, with strong performance even near the phase transition, and that it extends to multimarginal and inertial settings.

Abstract

We propose an algorithm, termed the Non-Equilibrium Transport Sampler (NETS), to sample from unnormalized probability distributions. NETS can be viewed as a variant of annealed importance sampling (AIS) based on Jarzynski's equality, in which the stochastic differential equation used to perform the non-equilibrium sampling is augmented with an additional learned drift term that lowers the impact of the unbiasing weights used in AIS. We show that this drift is the minimizer of a variety of objective functions, which can all be estimated in an unbiased fashion without backpropagating through solutions of the stochastic differential equations governing the sampling. We also prove that some of these objectives control the Kullback-Leibler divergence of the estimated distribution from its target. NETS is shown to be unbiased and, in addition, has a tunable diffusion coefficient which can be adjusted post-training to maximize the effective sample size. We demonstrate the efficacy of the method on standard benchmarks, high-dimensional Gaussian mixture distributions, and a model from statistical lattice field theory, for which it surpasses the performance of related work and existing baselines.

Paper Structure (39 sections, 17 theorems, 151 equations, 4 figures, 2 tables, 1 algorithm)

Key Result

Proposition 1

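Let $(X_t,A_t)$ solve the coupled system of SDE/ODE

$$dX_t = -\varepsilon_t \nabla U_t(X_t)\,dt + \sqrt{2\varepsilon_t}\,dW_t, \qquad X_0 \sim \rho_0,$$

$$dA_t = -\partial_t U_t(X_t)\,dt, \qquad A_0 = 0,$$

where $U_t$ is the time-dependent potential defining $\rho_t \propto e^{-U_t}$, $\varepsilon_t\ge 0$ is a time-dependent diffusion coefficient, and $W_t \in \mathbb R^d$ is the Wiener process. Then for all $t\in[0,1]$ and any test function $h:\mathbb{R}^d\to\mathbb{R}$, we have

$$\int_{\mathbb{R}^d} h(x)\,\rho_t(x)\,dx = \frac{\mathbb{E}\big[e^{A_t} h(X_t)\big]}{\mathbb{E}\big[e^{A_t}\big]}, \qquad \frac{Z_t}{Z_0} = \mathbb{E}\big[e^{A_t}\big],$$

where $Z_t = \int_{\mathbb{R}^d} e^{-U_t(x)}\,dx$ and the expectations are taken over the law of $(X_t, A_t)$.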
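
To make the reweighting in Proposition 1 concrete, here is a minimal sketch of plain annealed Langevin dynamics with Jarzynski weights, without the learned drift that NETS adds. It assumes the standard setup $dX_t = -\varepsilon \nabla U_t(X_t)\,dt + \sqrt{2\varepsilon}\,dW_t$ with log-weight $dA_t = -\partial_t U_t(X_t)\,dt$, a hypothetical one-dimensional path $U_t(x) = (x - m t)^2/2$ (so $\rho_t = N(mt, 1)$ and the normalization constant does not change), a constant $\varepsilon$, and Euler-Maruyama time stepping. The function and parameter names are illustrative and do not come from the paper.

```python
import math
import random


def annealed_langevin_jarzynski(m=1.0, eps=1.0, n_steps=100, n_walkers=2000, seed=0):
    """Simulate walkers (X_t, A_t) on t in [0, 1] for the toy path
    U_t(x) = (x - m*t)**2 / 2 (an assumption made for this sketch).

    Returns the final positions X_1 and the log-weights A_1.
    """
    rng = random.Random(seed)
    dt = 1.0 / n_steps
    noise_scale = math.sqrt(2.0 * eps * dt)
    xs, log_ws = [], []
    for _ in range(n_walkers):
        x = rng.gauss(0.0, 1.0)  # X_0 ~ rho_0 = N(0, 1)
        a = 0.0                  # A_0 = 0
        for k in range(n_steps):
            t = k * dt
            grad_u = x - m * t        # gradient of U_t in x
            dt_u = -m * (x - m * t)   # partial derivative of U_t in t
            a -= dt_u * dt            # dA = -d_t U_t(X_t) dt
            x += -eps * grad_u * dt + noise_scale * rng.gauss(0.0, 1.0)
        xs.append(x)
        log_ws.append(a)
    return xs, log_ws


def weighted_mean(xs, log_ws):
    """Self-normalized estimate E[e^A h(X)] / E[e^A] with h(x) = x."""
    shift = max(log_ws)  # subtract the max for numerical stability
    ws = [math.exp(a - shift) for a in log_ws]
    return sum(w * x for w, x in zip(ws, xs)) / sum(ws)


if __name__ == "__main__":
    xs, log_ws = annealed_langevin_jarzynski()
    print("unweighted mean of X_1:", sum(xs) / len(xs))
    print("Jarzynski-weighted mean:", weighted_mean(xs, log_ws))
    print("mean weight (estimates Z_1/Z_0):",
          sum(math.exp(a) for a in log_ws) / len(log_ws))
```

In this toy problem the walkers lag behind the moving target, so the unweighted mean of $X_1$ falls well short of $m$, while the weighted estimate recovers it up to Monte Carlo and time-discretization error. The learned drift in NETS is meant to reduce that lag so the weights stay close to uniform.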

Figures (4)

  • Figure 1: Comparison of the performance of annealed Langevin dynamics alone, transport alone, and annealed Langevin coupled with transport when sampling the 40-mode GMM from Midgley et al. (2023). Left: Annealed Langevin run for 250 steps with $\varepsilon_t = 4.0$, failing to capture the modes with $0\%$ ESS. Center: Learning using the PINN loss and sampling with 100 steps and $\varepsilon_t = 0$ achieves an ESS of $95\%$. Right: Same learning and now sampling with $\varepsilon_t = 4.0$ achieves an ESS of $98\%$.
  • Figure 2: Demonstration of high-dimensional sampling with our method using the PINN loss and a study of how diffusivity impacts performance, with and without transport. Left: NETS can achieve high ESS through transport alone, and increased diffusivity has more of a positive effect on performance with sampling than without. AIS cannot achieve ESS above $\approx 0$ in high dimension. Right: Kernel density estimates of 2-$d$ cross sections of the high-dimensional, multimodal distribution arising from the model and ground truth.
  • Figure 3: Comparison of the performance of NETS to AIS on two different settings for the study of $\varphi^4$ theory. Top row, left: 10 example generated lattice configurations with parameters $L=20$, $m^2 = -1.0$, $\lambda = 0.9$, which demarcates the phase transition to the antiferromagnetic phase. Top row, right: Performance of AIS (purple curve) vs. NETS (red curve) in terms of effective sample size over time of integration $t$, and a histogram of the average magnetization of $4000$ lattice configurations, sampled with AIS, NETS, and HMC (superposed in this order). Note that NETS is closer to the HMC target and re-weights correctly. Re-weighted AIS was not plotted because the variance of its weights was too high. Bottom row: Equivalent setup for $L=16$, $m^2 = -1.0$, $\lambda = 0.8$, past the phase transition and into the ordered phase. Note that the field configurations generated by NETS are either all positive across lattice sites or all negative. AIS fails to sample the correct distribution, and the variance of its weights is too high for them to be used on the histogram.
  • Figure 4: Reduction in $\mathcal{W}_2$ distance from taking the $\varepsilon \to \infty$ limit in sampling with NETS. Note that the resolution of the SDE integration must increase to accommodate the higher stochasticity. Average taken over 3 sampling runs of 2000 walkers each.
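
The ESS percentages quoted in the captions above are computed from importance weights. As a rough illustration, the sketch below uses the common normalized definition $\mathrm{ESS} = (\sum_i w_i)^2 / (n \sum_i w_i^2)$, computed from log-weights for numerical stability. Whether the paper uses exactly this normalization is an assumption, and the function name is hypothetical.

```python
import math


def ess_fraction(log_ws):
    """Normalized effective sample size in (0, 1] from log-weights.

    Equals 1 when all weights are equal and about 1/n when a single
    weight dominates.
    """
    n = len(log_ws)
    shift = max(log_ws)  # subtract the max so exp() cannot overflow
    ws = [math.exp(lw - shift) for lw in log_ws]
    total = sum(ws)
    total_sq = sum(w * w for w in ws)
    return total * total / (n * total_sq)


if __name__ == "__main__":
    print(ess_fraction([0.0, 0.0, 0.0, 0.0]))            # uniform weights
    print(ess_fraction([0.0, -1000.0, -1000.0, -1000.0]))  # one walker dominates
```

An ESS near $0\%$, as reported for AIS in the high-dimensional experiments, corresponds to the second case: almost all of the weight sits on a handful of walkers.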

Theorems & Definitions (27)

  • Proposition 1: Jarzynski equality
  • Remark 1
  • Proposition 2: Sampling with perfect additional transport.
  • Proposition 3: Nonequilibrium Transport Sampler (NETS)
  • Proposition 4: PINN objective
  • Proposition 5: KL control
  • Proposition 6: Action Matching objective
  • Remark 2
  • Proposition 7
  • Proposition 8
  • ...and 17 more