Scalable Stochastic Gradient Riemannian Langevin Dynamics in Non-Diagonal Metrics
Hanlin Yu, Marcelo Hartmann, Bernardo Williams, Arto Klami
TL;DR
The paper tackles efficient Bayesian inference for large neural networks by extending stochastic-gradient Riemannian Langevin dynamics (SGRLD) to non-diagonal Riemannian metrics. It introduces two scalable metrics, Monge and Shampoo, that encode curvature information without prohibitive cost: Monge uses a rank-one update to $I_D$, while Shampoo employs Kronecker-factored blocks to capture parameter correlations. Across MNIST, CIFAR-10, and funnel-like tests, Shampoo consistently improves log-likelihood and accuracy. Monge, whose metric is driven by the gradient $\boldsymbol{\nabla l}$, provides gains in some settings but requires careful tuning of its hyperparameter, and on easier posteriors it sometimes only matches the identity metric. The results demonstrate practical, scalable curvature-informed sampling that outperforms traditional diagonal approaches, especially under heavier priors, and highlight considerations for stability and computational overhead. The work suggests that non-diagonal SGRLD metrics can be viable in real-world, high-dimensional Bayesian neural networks and may be integrated with broader advances in scalable probabilistic inference.
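To illustrate why a rank-one update to $I_D$ stays cheap, the following is a minimal sketch, not the authors' implementation. It assumes a metric of the form $G = I_D + \alpha^2 \, \boldsymbol{\nabla l} \, \boldsymbol{\nabla l}^\top$; this summary does not state the paper's exact parameterization, and the names `alpha2`, `rank_one_inverse_apply`, and `dense_apply` are hypothetical. Under that assumption, the Sherman-Morrison identity lets $G^{-1}$ be applied to a vector in $O(D)$ time without forming a $D \times D$ matrix.

```python
def rank_one_inverse_apply(grad, vec, alpha2):
    """Apply (I + alpha2 * grad grad^T)^{-1} to vec in O(D).

    Uses the Sherman-Morrison identity. The metric form and the name
    `alpha2` are assumptions made for this illustration only.
    """
    g_dot_v = sum(g * v for g, v in zip(grad, vec))
    g_dot_g = sum(g * g for g in grad)
    scale = alpha2 * g_dot_v / (1.0 + alpha2 * g_dot_g)
    return [v - scale * g for g, v in zip(grad, vec)]


def dense_apply(grad, vec, alpha2):
    """Reference check: multiply vec by the dense G = I + alpha2 * grad grad^T (O(D^2))."""
    d = len(grad)
    return [
        sum(((1.0 if i == j else 0.0) + alpha2 * grad[i] * grad[j]) * vec[j]
            for j in range(d))
        for i in range(d)
    ]


# Toy values chosen only for the check below.
grad = [0.5, -1.0, 2.0]
vec = [1.0, 2.0, 3.0]
alpha2 = 0.1

x = rank_one_inverse_apply(grad, vec, alpha2)
# Multiplying back by the dense metric should recover the original vector.
roundtrip = dense_apply(grad, x, alpha2)
```

The dense product is included only to verify the shortcut on a toy problem; in a large network only the $O(D)$ routine would be practical.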
Abstract
Stochastic-gradient sampling methods are often used to perform Bayesian inference on neural networks. Methods that incorporate notions from differential geometry have been observed to perform better, with the Riemannian metric improving posterior exploration by accounting for the local curvature. However, existing methods often resort to simple diagonal metrics to remain computationally efficient, which loses some of these gains. We propose two non-diagonal metrics that can be used in stochastic-gradient samplers to improve convergence and exploration while incurring only a minor computational overhead over diagonal metrics. We show that these metrics can provide improvements for fully connected neural networks (NNs) with sparsity-inducing priors and for convolutional NNs with correlated priors. For some other choices, the posterior is easy enough that the simpler metrics already suffice.
