Preconditioning Hamiltonian Monte Carlo by minimizing Fisher Divergence

Adrian Seyboldt, Eliot L. Carlson, Bob Carpenter

Abstract

Although Hamiltonian Monte Carlo (HMC) scales as O(d^(1/4)) in dimension, there is a large constant factor determined by the curvature of the target density. This constant factor can be reduced in most cases through preconditioning, the state of the art for which uses diagonal or dense penalized maximum likelihood estimation of (co)variance based on a sample of warmup draws. These estimates converge slowly in the diagonal case and scale poorly when expanded to the dense case. We propose a more effective estimator based on minimizing the sample Fisher divergence from a linearly transformed density to a standard normal distribution. We present this estimator in three forms, (a) diagonal, (b) dense, and (c) low-rank plus diagonal. Using a collection of 114 models from posteriordb, we demonstrate that the diagonal minimizer of Fisher divergence outperforms the industry-standard variance-based diagonal estimators used by Stan and PyMC by a median factor of 1.3. The low-rank plus diagonal minimizer of the Fisher divergence outperforms Stan and PyMC's diagonal estimators by a median factor of 4.

Paper Structure (35 sections, 12 theorems, 90 equations, 7 figures, 1 table, 4 algorithms)

Key Result

Theorem 2.2

The estimated divergence is minimized when $(\mu, \sigma^2) = (\mu^*, {\sigma^2}^*)$, where $\operatorname{diag}^{-1}:\mathbb{R}^{d \times d} \rightarrow \mathbb{R}^d$ extracts the diagonal of a matrix as a vector and $\#$ is the geometric mean operator in the affine-invariant Riemannian metric over symmetric, positive-definite matrices.
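
The diagonal case can be illustrated with a minimal sketch. It assumes the sample objective is the mean over draws of $\lVert \sigma \odot \alpha_n + (x_n - \mu)/\sigma \rVert^2$, where $x_n$ are draws and $\alpha_n$ are gradients of the log density at those draws. Under that assumption, setting the derivatives to zero gives $\sigma^2 = \sqrt{\operatorname{var}(x)/\operatorname{var}(\alpha)}$ and $\mu = \bar{x} + \sigma^2 \bar{\alpha}$ elementwise. The function name and arguments are hypothetical and are not taken from nutpie or Stan; the closed form is derived from the assumed objective rather than copied from the theorem statement.

```python
import numpy as np


def fisher_diag_preconditioner(draws, grads):
    """Diagonal minimizer of the assumed sample objective
    mean_n || sigma * grad_n + (draw_n - mu) / sigma ||^2.

    draws, grads: arrays of shape (n_draws, dim); grads holds the gradient
    of the log density at each draw. Illustrative names, not a library API.
    """
    draws = np.asarray(draws, dtype=float)
    grads = np.asarray(grads, dtype=float)
    var_x = draws.var(axis=0)            # per-dimension variance of the draws
    var_g = grads.var(axis=0)            # per-dimension variance of the gradients
    sigma2 = np.sqrt(var_x / var_g)      # stationary point in sigma^2
    mu = draws.mean(axis=0) + sigma2 * grads.mean(axis=0)
    return mu, sigma2


# Toy check on an independent Gaussian target, where the gradient of the
# log density is -(x - mean) / variance.
rng = np.random.default_rng(0)
true_mean = np.array([1.0, -2.0, 0.5])
true_var = np.array([0.01, 4.0, 100.0])
draws = true_mean + np.sqrt(true_var) * rng.standard_normal((5, 3))
grads = -(draws - true_mean) / true_var

mu_hat, sigma2_hat = fisher_diag_preconditioner(draws, grads)
sample_var = draws.var(axis=0, ddof=1)   # variance-based estimate, for comparison
```

For this independent Gaussian target, the gradient is an exact linear function of the draw, so the estimate matches the true mean and variances from only five draws (up to floating-point error), while the sample variance still carries sampling noise.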

Figures (7)

  • Figure 1: Simulated log condition numbers $\kappa$ for Fisher- and covariance-based diagonal preconditioners. For each experiment, we generate 1000 random covariance spectra ($\Sigma = D^{1/2}U\operatorname{diag}\left( \lambda^2 \right)U^TD^{1/2}$, $U\sim \text{Uniform}((0, 200))$, $\log(\lambda_i)\sim \mathcal{N}(0, \sigma^2)$, $\log(D_{i,i})\sim \mathcal{N}(0, 2^2)$) and simulate sets of [10, 20, 50] i.i.d. draws from $\mathcal{N}(0, \Sigma)$. We then calculate the condition number $\kappa'$ of the adapted posterior with variance-based and Fisher-divergence-based diagonal preconditioning.
  • Figure 2: Cumulative distribution plots for target densities in posteriordb. Each line represents a different sampler, nutpie (blue), Stan (orange), and nutpie low-rank (green), run for 1000 warmup iterations and 1000 posterior draws. The top row of plots shows the raw diagnostic, while the bottom row shows the ratio of the diagnostic to Stan's. For instance, a point $(x,y)$ on the lower gradients plot indicates that a fraction $y$ of posteriordb models required at most $x$ times as many gradient evaluations per effective sample as Stan. Lines to the right of 1 show a small fraction of models for which Stan's default outperforms the other options.
  • Figure 3: Trace plot of sample log density scaled by number of gradient evaluations for three models from posteriordb. Each row corresponds to one model; the rows show nutpie with diagonal mass matrix adaptation, nutpie with low-rank adaptation, and Stan, respectively. The x-axis shows the number of gradient evaluations rather than the number of draws, as is typical for a trace plot. All samplers run 1000 warmup draws. nutpie uses significantly fewer gradient evaluations in total and avoids a long, inefficient phase seen in Stan. Figure \ref{fig:warmup-trace-random} shows this trace plot for 15 randomly selected models from posteriordb.
  • Figure 4: Example mass matrix adaptation scheme with background (below) and foreground (above) covariance estimators, with a switch/flush frequency of 10 draws. $\Sigma_i$ represents the covariance estimate used to sample draw $n+1$. Labels beneath them indicate the indices of the draw-gradient pairs on which that estimate is based. The transformation used for the Hamiltonian at each iteration is informed by the estimate stored in the foreground's state.
  • Figure 5: Cumulative distribution plots for target densities in posteriordb. Each line represents a different sampler, nutpie (blue), Stan (orange), and nutpie low-rank (green), run for 1000 warmup iterations and 1000 posterior draws. The top row of plots shows the raw diagnostic, while the bottom row shows the ratio of the diagnostic to Stan's. In the bottom left plot, the points to the right of 1 represent a small fraction of models for which Stan's default outperforms the other options in wall time. Interestingly, for a sizable chunk of models, nutpie lags behind Stan in minimum effective draws. However, this is more than made up for by the reduction in the number of required gradient evaluations.
  • ...and 2 more figures
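
The simulation described in the Figure 1 caption can be sketched roughly as follows. This is a reduced illustration under stated assumptions, not the authors' experiment code: the dimension is 20 for speed, $U$ is drawn as a random orthogonal matrix via a QR decomposition (the caption's distribution for $U$ is unclear in this extract), gradients are the exact Gaussian scores $-\Sigma^{-1}x$, and the Fisher-based scale uses the elementwise form $\sigma^2 = \sqrt{\operatorname{var}(x)/\operatorname{var}(\alpha)}$. All function names are hypothetical.

```python
import numpy as np


def random_factor(dim, spectrum_sd, rng):
    """Return F with F @ F.T = D^{1/2} U diag(lambda^2) U^T D^{1/2}."""
    q, r = np.linalg.qr(rng.standard_normal((dim, dim)))
    u = q * np.sign(np.diag(r))                          # random orthogonal U (assumption)
    lam = np.exp(rng.normal(0.0, spectrum_sd, size=dim)) # log(lambda_i) ~ N(0, sd^2)
    d = np.exp(rng.normal(0.0, 2.0, size=dim))           # log(D_ii) ~ N(0, 2^2)
    return np.sqrt(d)[:, None] * (u * lam)


def diagonal_scales(draws, cov):
    """Variance-based and Fisher-based diagonal scale estimates (sigma^2)."""
    grads = -np.linalg.solve(cov, draws.T).T             # score of N(0, cov) at each draw
    var_based = draws.var(axis=0, ddof=1)
    fisher_based = np.sqrt(draws.var(axis=0) / grads.var(axis=0))
    return var_based, fisher_based


def condition_number(cov, sigma2):
    """Condition number of the covariance after diagonal rescaling."""
    s = np.sqrt(sigma2)
    eig = np.linalg.eigvalsh(cov / np.outer(s, s))       # ascending eigenvalues
    return eig[-1] / eig[0]


def compare(dim=20, n_draws=10, spectrum_sd=1.0, seed=0):
    rng = np.random.default_rng(seed)
    factor = random_factor(dim, spectrum_sd, rng)
    cov = factor @ factor.T
    draws = rng.standard_normal((n_draws, dim)) @ factor.T  # i.i.d. N(0, cov)
    var_based, fisher_based = diagonal_scales(draws, cov)
    return condition_number(cov, var_based), condition_number(cov, fisher_based)


kappa_var, kappa_fisher = compare()
print(f"log10 kappa, variance-based: {np.log10(kappa_var):.2f}")
print(f"log10 kappa, Fisher-based:   {np.log10(kappa_fisher):.2f}")
```

Repeating `compare` over many seeds and draw counts would give distributions of log condition numbers comparable in spirit to those in the figure; a single run is only one noisy sample and should not be read as a result.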

Theorems & Definitions (27)

  • Definition 2.1: Fisher divergence
  • Theorem 2.2: Diagonal divergence minimization
  • proof
  • Theorem 2.3: Dense minimization
  • proof
  • Theorem 2.4
  • proof
  • Lemma 3.1: Scale free
  • proof
  • Definition A.1: AIRM
  • ...and 17 more