
Accelerated Parallel Tempering via Neural Transports

Leo Zhang, Peter Potaptchik, Jiajun He, Yuanqi Du, Arnaud Doucet, Francisco Vargas, Hai-Dang Dau, Saifuddin Syed

Abstract

Markov Chain Monte Carlo (MCMC) algorithms are essential tools in computational statistics for sampling from unnormalised probability distributions, but can be fragile when targeting high-dimensional, multimodal, or otherwise complex distributions. Parallel Tempering (PT) enhances MCMC's sample efficiency through annealing and parallel computation, propagating samples from tractable reference distributions to intractable targets via state swapping across interpolating distributions. The effectiveness of PT is limited by the often minimal overlap between adjacent distributions in challenging problems, which requires increasing the computational resources to compensate. We introduce a framework that accelerates PT by leveraging neural samplers (including normalising flows, diffusion models, and controlled diffusions) to reduce the required overlap. Our approach utilises neural samplers in parallel, circumventing their computational burden while preserving the asymptotic consistency of classical PT. We demonstrate theoretically and empirically on a variety of multimodal sampling problems that our method improves sample quality, reduces the computational cost compared to classical PT, and enables efficient free energy/normalising constant estimation.
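The classical PT mechanism summarised above (annealing from a reference to a target, parallel local exploration, and swaps between adjacent distributions) can be sketched as follows. This is a minimal illustration of classical non-reversible PT with deterministic even/odd swaps, not the accelerated method: the geometric path $\pi_\beta \propto \pi_0^{1-\beta}\pi^\beta$, the random-walk Metropolis local move, the bimodal toy target, and all function names (`log_ref`, `log_target`, `run_pt`) are assumptions made for illustration. It also counts round trips (a replica travelling from the reference chain to the target chain and back), the metric reported in Figures 1 and 2.

```python
import numpy as np


def log_ref(x):
    # Standard Gaussian reference pi_0, up to an additive constant.
    return -0.5 * np.sum(x**2, axis=-1)


def log_target(x, m=3.0):
    # Toy bimodal target (assumption): equal mixture of N(+m, I) and N(-m, I).
    a = -0.5 * np.sum((x - m) ** 2, axis=-1)
    b = -0.5 * np.sum((x + m) ** 2, axis=-1)
    return np.logaddexp(a, b)


def log_annealed(x, beta):
    # Geometric path (assumption): log pi_beta = (1 - beta) log pi_0 + beta log pi.
    return (1.0 - beta) * log_ref(x) + beta * log_target(x)


def run_pt(n_chains=6, dim=2, n_iters=5000, step=0.5, seed=0):
    rng = np.random.default_rng(seed)
    betas = np.linspace(0.0, 1.0, n_chains)
    x = rng.standard_normal((n_chains, dim))  # x[n] is the state of chain n
    replica = np.arange(n_chains)             # replica[n] is the replica currently at chain n
    # Per-replica phase: 0 = never at reference; 1 = visited reference, target not yet
    # reached since; 2 = reached target since its last reference visit.
    phase = np.zeros(n_chains, dtype=int)
    phase[replica[0]] = 1
    trips, tried, rejected = 0, 0, 0
    samples = []
    for t in range(n_iters):
        # Local exploration: one random-walk Metropolis step per chain, vectorised over chains.
        prop = x + step * rng.standard_normal(x.shape)
        log_acc = log_annealed(prop, betas) - log_annealed(x, betas)
        accept = np.log(rng.uniform(size=n_chains)) < log_acc
        x[accept] = prop[accept]
        # Communication: swap pairs (n-1, n); pairs (0,1),(2,3),... on even t, (1,2),(3,4),... on odd t.
        ell = log_target(x) - log_ref(x)
        for n in range(1 + t % 2, n_chains, 2):
            # log of pi_{n-1}(x_n) pi_n(x_{n-1}) / (pi_{n-1}(x_{n-1}) pi_n(x_n))
            log_ratio = (betas[n] - betas[n - 1]) * (ell[n - 1] - ell[n])
            tried += 1
            if np.log(rng.uniform()) < log_ratio:
                x[[n - 1, n]] = x[[n, n - 1]]
                replica[[n - 1, n]] = replica[[n, n - 1]]
            else:
                rejected += 1
        # Round trips: count a replica returning to the reference after reaching the target.
        top, bottom = replica[-1], replica[0]
        if phase[top] == 1:
            phase[top] = 2
        if phase[bottom] == 2:
            trips += 1
        phase[bottom] = 1
        samples.append(x[-1].copy())
    return np.array(samples), trips, rejected / tried
```

For example, `samples, trips, rej = run_pt()` returns the 5,000 target-chain states, the number of completed round trips, and the empirical swap rejection rate. The accelerated swaps of APT, which use neural samplers, are not reproduced in this sketch.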

Paper Structure

This paper contains 70 sections, 13 theorems, 129 equations, 7 figures, 4 tables, 2 algorithms.

Key Result

Theorem 1

The APT Markov chain $\mathbf{X}_t$ generated by the APT algorithm is ergodic and $\boldsymbol{\pi}$-invariant. Moreover, at stationarity, the probability that the $n$-th accelerated swap is rejected equals $r(\mathbb{P}^{n-1}_K, \mathbb{Q}^n_K)$.
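Theorem 1 expresses the stationary rejection probability through a functional $r$. Assuming $r$ is the usual swap-rejection functional from the PT literature (its exact definition in the paper, and the construction of the path measures $\mathbb{P}^{n-1}_K$ and $\mathbb{Q}^n_K$, are not reproduced here), it can be written for generic probability measures $\mathbb{P}$ and $\mathbb{Q}$ with $\mathbb{Q} \ll \mathbb{P}$ as below, with classical PT recovered by taking $\mathbb{P} = \pi_{n-1} \otimes \pi_n$ and $\mathbb{Q} = \pi_n \otimes \pi_{n-1}$:

```latex
r(\mathbb{P}, \mathbb{Q})
  = \mathbb{E}_{\mathbb{P}}\!\left[1 - \min\!\left(1, \frac{d\mathbb{Q}}{d\mathbb{P}}\right)\right]
  = 1 - \int \min(d\mathbb{P}, d\mathbb{Q})
  = \mathrm{TV}(\mathbb{P}, \mathbb{Q}),
\qquad
\frac{d(\pi_n \otimes \pi_{n-1})}{d(\pi_{n-1} \otimes \pi_n)}(x, y)
  = \frac{\pi_{n-1}(y)\,\pi_n(x)}{\pi_{n-1}(x)\,\pi_n(y)} .
```

Under this reading, a smaller mismatch between the two measures means fewer rejected swaps; in classical PT the mismatch is governed by the overlap of adjacent distributions, the bottleneck identified in the abstract.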

Figures (7)

  • Figure 1: (Left) The local exploration and communication steps of PT vs APT. (Middle) 1,000 samples from a Gaussian mixture model target obtained with PT and with APT, using a standard Gaussian reference. See the Appendix for more details. (Right) Round trips for PT and APT with $N=6$ chains over $T=100,000$ iterations of the APT algorithm.
  • Figure 2: Round trip metrics with 30 chains for $K$-step Diff-APT ($K=1, 2, 5$), Diff-PT using the true diffusion path, and PT, targeting GMM-$d$ for $d=2, 10, 50, 100$. (Left) Round trip rate against $d$. (Right) Compute-normalised round trip rate against $d$.
  • Figure 3: Estimates of $\Delta F$ for DW4 and ManyWell-32 obtained by PT, CMCD-APT ($K=1, 2, 5$) and Diff-APT ($K=0, 1, 2, 5$) from 1,000 samples. Each box summarises 30 estimates. The black dashed lines mark the reference values: $\Delta F \approx 29.660$ for DW4, estimated with PT using 60 chains and 100,000 samples, and $\Delta F \approx 164.696$ for ManyWell-32, taken from midgley2022flow.
  • Figure 4: Interatomic distances $d_{ij}$ on DW4 for 5,000 samples generated by CMCD, CMCD-APT, Diffusion and Diff-APT (30 chains; $K=1, 2, 5$). Ground truth: 100,000 samples from PT with 60 chains.
  • Figure 5: Visualisation of ManyWell-32 samples generated by 1,000 consecutive CMCD-APT steps.
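Figure 3 compares estimates of $\Delta F$. As a point of reference only, a classical estimator of the log normalising-constant ratio $\log(Z_1/Z_0)$ from PT-style samples is the stepping-stone estimator, which chains importance-sampling estimates between adjacent annealed distributions: $\log(Z_1/Z_0) = \sum_{n} \log \mathbb{E}_{\pi_{\beta_{n-1}}}[\exp((\beta_n - \beta_{n-1})\,\ell(X))]$ with $\ell = \log \pi - \log \pi_0$. The sketch below assumes the geometric path, uses hypothetical function names, and is not claimed to be the estimator behind the figure. Free energy differences are commonly reported as $-\log(Z_1/Z_0)$ in units of $k_BT$; the sign convention used in the figure is not stated here.

```python
import numpy as np


def log_mean_exp(a):
    # Numerically stable log(mean(exp(a))).
    m = np.max(a)
    return m + np.log(np.mean(np.exp(a - m)))


def stepping_stone_log_ratio(ell_by_chain, betas):
    """Estimate log(Z_1 / Z_0) along pi_beta ∝ pi_0^(1-beta) * pi^beta (assumed path).

    ell_by_chain[n] holds ell(x) = log pi(x) - log pi_0(x) (unnormalised densities)
    evaluated at samples from chain n, i.e. from pi_{betas[n]}.
    """
    total = 0.0
    for n in range(1, len(betas)):
        # log Z_{beta_n} / Z_{beta_{n-1}} = log E_{pi_{beta_{n-1}}}[exp((beta_n - beta_{n-1}) * ell)]
        total += log_mean_exp((betas[n] - betas[n - 1]) * ell_by_chain[n - 1])
    return total


def gaussian_check(sigma=2.0, n_chains=20, n_samples=10_000, seed=0):
    # Toy check with a known answer: pi_0(x) = exp(-x^2/2), pi(x) = exp(-x^2/(2 sigma^2)),
    # so log(Z_1 / Z_0) = log(sigma). Each pi_beta is Gaussian with precision
    # (1 - beta) + beta / sigma^2, so it is sampled exactly here instead of running PT.
    rng = np.random.default_rng(seed)
    betas = np.linspace(0.0, 1.0, n_chains)
    precisions = (1.0 - betas) + betas / sigma**2
    ell_by_chain = []
    for lam in precisions:
        x = rng.standard_normal(n_samples) / np.sqrt(lam)
        ell_by_chain.append(-0.5 * x**2 * (1.0 / sigma**2 - 1.0))
    return stepping_stone_log_ratio(ell_by_chain, betas), np.log(sigma)
```

`gaussian_check()` compares the estimate with the exact value $\log 2$ for $\sigma = 2$. In a PT run, `ell_by_chain[n]` would instead be filled from the stored states of chain $n$; those are correlated MCMC samples, so the estimate is noisier than with exact draws.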

Theorems & Definitions (25)

  • Theorem 1
  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Theorem 2
  • Proof
  • Proof
  • Lemma 1
  • Proof
  • Proof