Learning to Explore for Stochastic Gradient MCMC

SeungHyun Kim, Seohyeon Jung, Seonghyeon Kim, Juho Lee

TL;DR

This work targets Bayesian neural networks, where posterior inference is hindered by high dimensionality and multimodality. It introduces Learning to Explore (L2E), a meta-learning SGMCMC framework that parameterizes the gradients of the kinetic energy with neural networks $\alpha_\phi$ and $\beta_\phi$, while keeping the diffusion and curl components simple and eliminating the costly $\Gamma(z)$ term. The meta-objective is based on the predictive distribution and is optimized with unbiased gradient estimators. The sampler is meta-trained over a diverse distribution of tasks to promote transfer to unseen datasets and architectures. Empirically, on image benchmarks L2E achieves faster mixing, better predictive accuracy, better coverage of multiple posterior modes, and robustness under distribution shift, with modest computational overhead compared to standard SGMCMC baselines.
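To make the idea of pluggable kinetic-energy gradients concrete, the following is a minimal sketch of one SGHMC-style update in which the two kinetic-gradient terms are arbitrary callables. The discretisation, the function names (`sgmcmc_step`, `sghmc_alpha`, `sghmc_beta`), the callable inputs `(theta, r, g)`, and the constant friction `c` are all assumptions made for illustration; this is not the paper's exact update rule. In L2E the two callables would be the learned networks $\alpha_\phi$ and $\beta_\phi$; here they are hand-written stand-ins that recover plain SGHMC.

```python
import math
import random


def sgmcmc_step(theta, r, grad_u, alpha, beta, eps=0.01, c=1.0,
                temperature=1.0, rng=random):
    """One SGHMC-style update with pluggable kinetic-gradient terms.

    Assumed (hypothetical) discretisation:
        r     <- r - eps * (grad_u(theta) + c * alpha(theta, r, g)) + noise
        theta <- theta + eps * beta(theta, r, g)
    where noise ~ N(0, 2 * c * eps * temperature) per coordinate.
    With alpha = beta = "return r" this is plain SGHMC with identity mass.
    """
    g = grad_u(theta)                      # (stochastic) gradient of the energy
    a = alpha(theta, r, g)                 # kinetic-gradient term in the friction
    noise_std = math.sqrt(2.0 * c * eps * temperature)
    r_new = [ri - eps * (gi + c * ai) + noise_std * rng.gauss(0.0, 1.0)
             for ri, gi, ai in zip(r, g, a)]
    b = beta(theta, r_new, g)              # kinetic-gradient term moving theta
    theta_new = [ti + eps * bi for ti, bi in zip(theta, b)]
    return theta_new, r_new


def grad_u_gaussian(theta):
    """Gradient of U(theta) = ||theta||^2 / 2 (standard normal target)."""
    return list(theta)


def sghmc_alpha(theta, r, g):
    """Stand-in for a learned alpha: quadratic kinetic energy gives r."""
    return list(r)


def sghmc_beta(theta, r, g):
    """Stand-in for a learned beta: quadratic kinetic energy gives r."""
    return list(r)


# Short demonstration chain on a 1-D standard normal target.
rng = random.Random(0)
theta, r = [3.0], [0.0]
samples = []
for _ in range(2000):
    theta, r = sgmcmc_step(theta, r, grad_u_gaussian, sghmc_alpha, sghmc_beta,
                           eps=0.05, c=1.0, rng=rng)
    samples.append(theta[0])
```

Keeping `alpha` and `beta` as plain callables means a learned sampler and a hand-designed one share the same update loop, so they can be compared by swapping only the two functions.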

Abstract

Bayesian Neural Networks (BNNs) with high-dimensional parameters pose a challenge for posterior inference due to the multi-modality of the posterior distributions. Stochastic Gradient MCMC (SGMCMC) with cyclical learning rate scheduling is a promising solution, but it requires a large number of sampling steps to explore high-dimensional multi-modal posteriors, making it computationally expensive. In this paper, we propose a meta-learning strategy to build SGMCMC which can efficiently explore multi-modal target distributions. Our algorithm allows the learned SGMCMC to quickly explore the high-density regions of the posterior landscape. We also show that this exploration property is transferable to various tasks, even ones unseen during the meta-training stage. Using popular image classification benchmarks and a variety of downstream tasks, we demonstrate that our method significantly improves sampling efficiency, achieving better performance than vanilla SGMCMC without incurring significant computational overhead.
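For readers unfamiliar with cyclical learning rate scheduling, the sketch below shows the cosine cyclical step-size schedule commonly used with cyclical SGMCMC. The abstract does not say which schedule the baseline uses, so the cosine form, the function name `cyclical_step_size`, and its parameters are assumptions for illustration only.

```python
import math


def cyclical_step_size(k, total_steps, num_cycles, eps0):
    """Cosine cyclical step size for the 1-indexed iteration k.

    Each cycle restarts at eps0 (large steps, exploration) and decays
    towards zero (small steps, sampling near a mode).
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    pos = (k - 1) % cycle_len              # position inside the current cycle
    return 0.5 * eps0 * (math.cos(math.pi * pos / cycle_len) + 1.0)


# Step sizes for 100 iterations split into 4 cycles of 25 steps each.
schedule = [cyclical_step_size(k, 100, 4, 0.1) for k in range(1, 101)]
```

Each restart lets the chain jump to a different mode, which is why many sampling steps are needed to cover a multi-modal posterior this way.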

Paper Structure

This paper contains 64 sections, 15 equations, 11 figures, 15 tables, 2 algorithms.

Figures (11)

  • Figure 1: Predictive performance of each method as the number of samples used for BMA increases. L2E exhibits superior predictive accuracy compared to the baseline methods. Note that only the Fashion-MNIST dataset is included in the meta-training task distribution. We also plot the performance of the reference samples from Izmailov et al. (2021) as a dashed line for CIFAR-10 and CIFAR-100. For meta-training details of Meta-SGMCMC, please refer to \ref{app:sec:metasgmcmctraining}.
  • Figure 2: Multi-modality of various methods , , Meta-SGMCMC and with ResNet20-FRN on CIFAR-10. (a) Cosine similarity between weights. (b) Loss surface as a function of model parameters in a 2-dimensional subspace spanned by solutions of each method; colors represent the level of test error. (c) Test error ($\%$) along a linear path between a pair of parameters. Due to the inferior performance of Meta-SGMCMC, the offset of its test error is adjusted individually on CIFAR-10.
  • Figure 3: Plots of $||\Delta \theta||^2$ and train during training of ,, on CIFAR-10. Unlike other methods, L2E actively updates $\theta$ in the local minima while maintaining training as nearly constant.
  • Figure 4: Contour plots of the absolute value of the outputs of $\beta_\phi$ (top) and $\alpha_\phi$ (bottom) on the grid. $\beta_{\phi}$ produces outputs of large magnitude when $\nabla_{\theta}\tilde{U}(\theta)$ is large. When $\nabla_{\theta}\tilde{U}(\theta)$ gets smaller, the overall magnitude decreases as expected, but even when $\nabla_{\theta}\tilde{U}(\theta)$ is nearly zero, $\beta_\phi$ can still allow the sampler to move around the posterior distribution when combined with a high momentum value. The regions marked with red dashed boxes can be beneficial for exploration in high-density regions. $\alpha_\phi$ is in general proportional to $\nabla_{\theta}\tilde{U}(\theta)$, which helps the sampler converge quickly to the high-density region.
  • Figure 5: Plots of $||\Delta \theta||^2$ and train during training of ,, on CIFAR-100.
  • ...and 6 more figures
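
Two of the diagnostics named in the captions above, predictive accuracy as more samples enter the Bayesian model average (Figure 1) and cosine similarity between weight vectors (Figure 2a), can be sketched as follows. This is a minimal illustration on toy numbers, not the paper's evaluation code; the function names and the nested-list data layout `prob_samples[s][n][k]` (sample, input, class) are assumptions.

```python
import math


def bma_predict(prob_samples):
    """Average class probabilities over posterior samples.

    prob_samples[s][n][k] is the probability that sample s assigns
    to class k for input n.
    """
    num_samples = len(prob_samples)
    num_inputs = len(prob_samples[0])
    num_classes = len(prob_samples[0][0])
    return [[sum(prob_samples[s][n][k] for s in range(num_samples)) / num_samples
             for k in range(num_classes)]
            for n in range(num_inputs)]


def accuracy(probs, labels):
    """Fraction of inputs whose arg-max class matches the label."""
    correct = 0
    for p, y in zip(probs, labels):
        predicted = max(range(len(p)), key=p.__getitem__)
        correct += int(predicted == y)
    return correct / len(labels)


def bma_accuracy_curve(prob_samples, labels):
    """Accuracy of the BMA built from the first m samples, m = 1..S."""
    return [accuracy(bma_predict(prob_samples[:m]), labels)
            for m in range(1, len(prob_samples) + 1)]


def cosine_similarity(u, v):
    """Cosine similarity between two flattened weight vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Toy example: two posterior samples, two inputs, two classes.
toy_probs = [
    [[0.9, 0.1], [0.6, 0.4]],   # sample 1 misclassifies the second input
    [[0.8, 0.2], [0.1, 0.9]],   # sample 2 is confident and correct on it
]
toy_labels = [0, 1]
curve = bma_accuracy_curve(toy_probs, toy_labels)
```

A low cosine similarity between collected weight vectors suggests the samples come from different regions of the posterior, while a rising BMA curve shows that adding such samples improves the averaged prediction.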