Scalable Stochastic Gradient Riemannian Langevin Dynamics in Non-Diagonal Metrics

Hanlin Yu, Marcelo Hartmann, Bernardo Williams, Arto Klami

TL;DR

The paper tackles efficient Bayesian inference for large neural networks by extending stochastic-gradient Langevin dynamics with non-diagonal Riemannian metrics. It introduces two scalable metrics, Monge and Shampoo, that encode curvature information without prohibitive cost: Monge applies a rank-one update to the identity $I_D$, while Shampoo uses Kronecker-factored blocks to capture parameter correlations. Across MNIST, CIFAR-10, and funnel-like tests, Shampoo consistently improves log-likelihood and accuracy. Monge provides gains in some settings but requires careful tuning of the hyperparameter $\alpha^2$ that scales its $\nabla l$-driven rank-one term, and on easier posteriors it sometimes only matches the identity metric. The results demonstrate practical, scalable curvature-informed sampling that outperforms traditional diagonal approaches, especially under heavier priors, and highlight considerations for stability and computational overhead. The work suggests that non-diagonal SGRLD metrics are viable for real-world, high-dimensional Bayesian neural networks and could be integrated with broader advances in scalable probabilistic inference.
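A minimal sketch of why a rank-one Monge metric stays cheap, assuming it takes the form $G = I_D + \alpha^2 \nabla l \nabla l^\top$ with $\nabla l$ the (stochastic) log-density gradient and $\alpha^2$ the hyperparameter mentioned above. The helper `monge_precondition` is a hypothetical name; it applies $G^{-1}$ via the Sherman-Morrison formula and $G^{-1/2}$ in closed form, both in $O(D)$ time and without ever forming a $D \times D$ matrix.

```python
import numpy as np

def monge_precondition(grad, alpha2):
    """Return closures applying G^{-1} and G^{-1/2} for G = I + alpha2 * g g^T.

    Assumed Monge-metric form: identity plus a rank-one term built from the
    log-density gradient g. Both closures cost O(D) per call.
    """
    g = np.asarray(grad, dtype=float)
    sq_norm = g @ g
    # Sherman-Morrison: G^{-1} = I - alpha2 / (1 + alpha2 * |g|^2) * g g^T
    c_inv = alpha2 / (1.0 + alpha2 * sq_norm)
    # G has eigenvalue 1 + alpha2 * |g|^2 along g and 1 elsewhere, so
    # G^{-1/2} = I + (1 / sqrt(1 + alpha2 * |g|^2) - 1) / |g|^2 * g g^T
    if sq_norm > 0:
        c_isqrt = (1.0 / np.sqrt(1.0 + alpha2 * sq_norm) - 1.0) / sq_norm
    else:
        c_isqrt = 0.0  # zero gradient: G is the identity

    def inv(v):
        return v - c_inv * (g @ v) * g

    def inv_sqrt(v):
        return v + c_isqrt * (g @ v) * g

    return inv, inv_sqrt
```

Applying `inv_sqrt` twice reproduces $G^{-1}$, so `inv_sqrt` can turn standard-normal noise into noise with covariance proportional to $G^{-1}$, which is what a Riemannian Langevin step needs.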

Abstract

Stochastic-gradient sampling methods are often used to perform Bayesian inference on neural networks. It has been observed that methods incorporating notions of differential geometry tend to perform better, with the Riemannian metric improving posterior exploration by accounting for the local curvature. However, existing methods often resort to simple diagonal metrics to remain computationally efficient, which sacrifices some of these gains. We propose two non-diagonal metrics that can be used in stochastic-gradient samplers to improve convergence and exploration, with only a minor computational overhead over diagonal metrics. We show that for fully connected neural networks (NNs) with sparsity-inducing priors and convolutional NNs with correlated priors, using these metrics can provide improvements. For some other choices, the posterior is easy enough that the simpler metrics also suffice.
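To illustrate how a non-diagonal metric can keep its overhead modest, here is a sketch of a Shampoo-style Kronecker preconditioner for a single $m \times n$ weight matrix. It follows the original Shampoo optimizer's direction $L^{-1/4} G R^{-1/4}$, built from accumulated left and right gradient statistics. The exact exponents, averaging, and damping used for the paper's Shampoo metric are not given in this summary, so `eps`, the plain sum over past gradients, and the function names are assumptions.

```python
import numpy as np

def matrix_power_psd(M, p):
    """M^p for a symmetric positive-definite matrix M via eigendecomposition."""
    w, V = np.linalg.eigh(M)
    return (V * w**p) @ V.T  # V diag(w^p) V^T

def shampoo_precondition(grads, eps=1e-4):
    """Shampoo-style preconditioned direction for one weight matrix.

    grads: list of gradient matrices of shape (m, n) observed so far.
    Returns (L^{-1/4} G R^{-1/4}, L^{-1/4}, R^{-1/4}) for the latest gradient G,
    where L = eps*I + sum G G^T (m x m) and R = eps*I + sum G^T G (n x n).
    """
    m, n = grads[-1].shape
    L = eps * np.eye(m) + sum(g @ g.T for g in grads)  # left (row) statistics
    R = eps * np.eye(n) + sum(g.T @ g for g in grads)  # right (column) statistics
    L_q = matrix_power_psd(L, -0.25)
    R_q = matrix_power_psd(R, -0.25)
    return L_q @ grads[-1] @ R_q, L_q, R_q
```

Multiplying the row-major flattened gradient by `np.kron(L_q, R_q)` gives the same result, but only after materialising an $mn \times mn$ matrix; the factored form stores an $m \times m$ and an $n \times n$ factor instead, which is where the savings over a full metric come from.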

Paper Structure (33 sections, 4 theorems, 54 equations, 4 figures, 10 tables)

Key Result

Theorem 2.1

Denote the operator norm by $\Vert\cdot\Vert$. Under the smoothness and boundedness assumptions given in the Appendix, and after ignoring the $\Gamma(\boldsymbol{\theta})$ terms during the discretized updates, for any quantity $\phi$ estimated as the empirical expectation, … for some constant $C > 0$ independent of $\{h_t\}$, where $\Delta V_{t}$ is an operator that is def…
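The discretized update this result refers to can be sketched under one common SGRLD convention: $\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t + \tfrac{h_t}{2} G^{-1}(\boldsymbol{\theta}_t) \nabla \log \tilde p(\boldsymbol{\theta}_t) + \sqrt{h_t}\, G^{-1/2}(\boldsymbol{\theta}_t)\, \xi_t$ with $\xi_t \sim \mathcal{N}(0, I)$ and the $\Gamma(\boldsymbol{\theta})$ correction dropped. The estimator is assumed to be the step-size-weighted average $\sum_t h_t \phi(\boldsymbol{\theta}_t) / \sum_t h_t$ used in standard SG-MCMC analyses. The function names and the `metric_fns` interface are hypothetical.

```python
import numpy as np

def sgrld_no_gamma(grad_log_post, theta0, step_sizes, metric_fns, rng):
    """SGRLD with the Gamma (metric-derivative) correction dropped.

    grad_log_post(theta): (stochastic) gradient of the log posterior.
    metric_fns(theta, grad): returns (apply_G_inv, apply_G_inv_sqrt).
    Update: theta <- theta + h/2 * G^{-1} grad + sqrt(h) * G^{-1/2} xi.
    """
    theta = np.asarray(theta0, dtype=float)
    samples = []
    for h in step_sizes:
        g = grad_log_post(theta)
        inv, inv_sqrt = metric_fns(theta, g)
        xi = rng.standard_normal(theta.shape)  # xi ~ N(0, I)
        theta = theta + 0.5 * h * inv(g) + np.sqrt(h) * inv_sqrt(xi)
        samples.append(theta.copy())
    return np.array(samples)


def weighted_estimate(phi, samples, step_sizes):
    """Step-size-weighted empirical expectation sum_t h_t phi(theta_t) / sum_t h_t."""
    h = np.asarray(step_sizes, dtype=float)
    vals = np.array([phi(s) for s in samples])
    return np.tensordot(h, vals, axes=1) / h.sum()


# Example: identity metric (plain SGLD) on a standard 2D Gaussian target.
def identity_metric(theta, grad):
    return (lambda v: v), (lambda v: v)

rng = np.random.default_rng(0)
steps = [0.1] * 20000
samples = sgrld_no_gamma(lambda th: -th, np.zeros(2), steps, identity_metric, rng)
second_moment = weighted_estimate(lambda th: th**2, samples[1000:], steps[1000:])
```

With the identity metric this reduces to plain SGLD, and `second_moment` should land close to the true value of 1 per coordinate, up to a small discretization bias that shrinks with the step size. Passing closures that apply a non-diagonal $G^{-1}$ and $G^{-1/2}$ turns the same loop into a non-diagonal sampler.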

Figures (4)

  • Figure 1: Illustration of how the Monge metric captures the local curvature of the target density, here the funnel distribution (see the section on the funnel experiments). The two plots show the local metric around two separate parameter settings, in the form of geodesic paths (red arrows) for different initial velocities sampled from a Euclidean ball, together with the surface of final positions (solid red line). The metric helps in reaching the narrow funnel in the y-direction (left) and is similar to the Euclidean metric in the flat areas (right).
  • Figure 2: Funnel distribution under four metrics. The better metrics explore the funnel noticeably more thoroughly, though even the best one has difficulty reaching the very end. The right subplots show the challenging marginal of the 2D distribution (left subplot), with the true marginal shown as a blue line and the samples as yellow histograms.
  • Figure 3: Log-probability (top) and accuracy (bottom) of different samplers on MNIST with hidden layer size $400$ and horseshoe prior. Shaded regions show $\pm 1.96$ standard deviations computed over $10$ replicates.
  • Figure 4: Log-probability (left) and accuracy (right) in Monge metric with varying $\alpha^{2}$ on MNIST with hidden layer size $400$ and horseshoe prior, in comparison to other metrics. Shaded areas show $\pm 1.96$ standard deviations computed over $10$ replicates.
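The funnel target in Figures 1 and 2 can be sketched under an assumed standard 2D parametrization of Neal's funnel, $v \sim \mathcal{N}(0, 3^2)$ and $x \mid v \sim \mathcal{N}(0, e^{v})$; the paper's exact variant and axis labelling may differ. The sketch gives the log-density, its gradient, and the stretch factor $1 + \alpha^2 \Vert \nabla l \Vert^2$ that a rank-one Monge metric of the assumed form $I + \alpha^2 \nabla l \nabla l^\top$ applies along the gradient direction. That factor is much larger in the narrow neck than in the wide mouth, in line with the behaviour described for Figure 1. All function names are hypothetical.

```python
import numpy as np

def funnel_log_density(z):
    """Unnormalised log-density of an assumed 2D Neal-style funnel.

    z = (x, v) with v ~ N(0, 3^2) and x | v ~ N(0, exp(v)).
    """
    x, v = z
    # -v^2/18 from the prior on v; -v/2 - x^2 e^{-v}/2 from N(x; 0, e^v)
    return -v**2 / 18.0 - 0.5 * v - 0.5 * x**2 * np.exp(-v)

def funnel_grad(z):
    """Analytic gradient of funnel_log_density with respect to (x, v)."""
    x, v = z
    ev = np.exp(-v)
    return np.array([-x * ev, -v / 9.0 - 0.5 + 0.5 * x**2 * ev])

def monge_stretch(grad, alpha2):
    """Eigenvalue of G = I + alpha2 * g g^T along g (all others equal 1)."""
    return 1.0 + alpha2 * (grad @ grad)

neck = funnel_grad(np.array([0.5, -4.0]))   # narrow region, small v
mouth = funnel_grad(np.array([0.0, 3.0]))   # wide region, large v
```

In the neck the gradient is large, so the assumed Monge metric stretches space strongly along it and shortens steps there; in the mouth the gradient is small and the metric stays close to Euclidean.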

Theorems & Definitions (4)

  • Theorem 2.1
  • Theorem 3.1
  • Theorem 3.2
  • Theorem A.2