Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam

Mohammad Emtiyaz Khan, Didrik Nielsen, Voot Tangkaratt, Wu Lin, Yarin Gal, Akash Srivastava

TL;DR

This work presents fast, scalable Bayesian deep learning by embedding weight perturbations into Adam to perform Gaussian mean-field variational inference. Building on approximate natural-gradient updates (VON, the variational online Newton method; VOGN, its Gauss-Newton variant; and Vprop) combined with natural momentum (Vadam) or a variational form of AdaGrad (VadaGrad), the approach obtains uncertainty estimates with less memory and computation than traditional VI methods. On logistic regression and neural networks, the uncertainty estimates are of comparable quality to state-of-the-art VI methods, and in reinforcement learning the weight perturbation helps exploration and early learning. The framework offers a practical route to Bayesian deep learning that integrates with standard adaptive optimizers and uses weight perturbation as a mechanism for exploration and uncertainty propagation.

Abstract

Uncertainty computation in deep learning is essential to design robust and reliable systems. Variational inference (VI) is a promising approach for such computation, but requires more effort to implement and execute compared to maximum-likelihood methods. In this paper, we propose new natural-gradient algorithms to reduce such efforts for Gaussian mean-field VI. Our algorithms can be implemented within the Adam optimizer by perturbing the network weights during gradient evaluations, and uncertainty estimates can be cheaply obtained by using the vector that adapts the learning rate. This requires lower memory, computation, and implementation effort than existing VI methods, while obtaining uncertainty estimates of comparable quality. Our empirical results confirm this and further suggest that the weight-perturbation in our algorithm could be useful for exploration in reinforcement learning and stochastic optimization.
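
The abstract's central claim is that a Gaussian mean-field posterior can be fit with an Adam-like loop: perturb the weights before each gradient evaluation and read the posterior standard deviation off the vector Adam already uses to adapt the learning rate. The NumPy sketch below follows our reading of the Vadam pseudocode compared in Figure 1. The exact update forms (σ = 1/√(Ns + λ), a prior term λμ/N added to the first moment, λ/N in place of Adam's ε), the function names `vadam_step` and `demo`, and the toy conjugate model are assumptions for illustration, not the authors' code.

```python
import numpy as np


def vadam_step(mu, m, s, t, grad_fn, n_data, prior_prec, rng,
               lr=0.01, beta1=0.9, beta2=0.999):
    """One Vadam-style update (sketch, assumed form). Returns new (mu, m, s, t)."""
    lam = prior_prec / n_data                           # lambda / N, used where Adam has epsilon
    sigma = 1.0 / np.sqrt(n_data * s + prior_prec)      # posterior std from the second-moment vector
    theta = mu + sigma * rng.standard_normal(mu.shape)  # weight perturbation before the gradient call
    g = grad_fn(theta, rng)                             # minibatch NLL gradient at perturbed weights
    m = beta1 * m + (1 - beta1) * (g + lam * mu)        # first moment includes the prior term
    s = beta2 * s + (1 - beta2) * g * g                 # second moment of the likelihood gradient
    t += 1
    m_hat = m / (1 - beta1 ** t)                        # Adam-style bias corrections
    s_hat = s / (1 - beta2 ** t)
    mu = mu - lr * m_hat / (np.sqrt(s_hat) + lam)
    return mu, m, s, t


def demo(n_steps=5000, seed=0):
    """Hypothetical toy model: y_i ~ N(theta, 1), prior theta ~ N(0, 1/prior_prec)."""
    rng = np.random.default_rng(seed)
    n, prior_prec = 100, 1.0
    y = 2.0 + rng.standard_normal(n)

    def grad_fn(theta, rng):
        i = rng.integers(n)   # minibatch of size 1
        return theta - y[i]   # gradient of 0.5 * (y_i - theta)^2

    mu, m, s, t = np.zeros(1), np.zeros(1), np.zeros(1), 0
    for _ in range(n_steps):
        mu, m, s, t = vadam_step(mu, m, s, t, grad_fn, n, prior_prec, rng)
    sigma = 1.0 / np.sqrt(n * s + prior_prec)
    exact_mean = y.sum() / (n + prior_prec)             # conjugate Gaussian posterior mean
    exact_std = 1.0 / np.sqrt(n + prior_prec)           # conjugate Gaussian posterior std
    return float(mu[0]), float(sigma[0]), exact_mean, exact_std
```

In this toy run the returned mean should land near the exact posterior mean and σ near 1/√(N+λ) ≈ 0.1, because with a minibatch of size 1 the squared-gradient vector tracks the per-example curvature. Treat this as a sanity check of the mechanism, not a reproduction of the paper's experiments.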

Paper Structure

This paper contains 40 sections, 2 theorems, 68 equations, 8 figures, 4 tables, 2 algorithms.

Key Result

Theorem 1

Denote the full-batch gradient with respect to $\theta_j$ by $g_j(\boldsymbol{\theta})$ and the corresponding full-batch GGN approximation by $h_j(\boldsymbol{\theta})$. Suppose minibatches $\mathcal{M}$ are sampled from the uniform distribution $p(\mathcal{M})$ […] where $w = \frac{N-M}{M(N-1)}$.
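
The statement above is cut off before its conclusion. Under the assumption that $g_j$ is the mean of the $N$ per-example gradients, $h_j$ is the mean of their squares, and the conclusion reads $\mathbb{E}_{p(\mathcal{M})}[\hat{g}_j(\boldsymbol{\theta};\mathcal{M})^2] = w\,h_j(\boldsymbol{\theta}) + (1-w)\,g_j(\boldsymbol{\theta})^2$ for a size-$M$ minibatch mean $\hat{g}_j$, the weight $w$ is exactly the finite-population correction for a sample mean drawn without replacement. The sketch below checks that reading by enumerating every minibatch; the function name is hypothetical.

```python
import itertools

import numpy as np


def minibatch_gm_expectation(a, M):
    """Exact E[(minibatch-mean gradient)^2] over all size-M minibatches, and the
    w*h + (1-w)*g^2 prediction (assumed reading of Theorem 1)."""
    a = np.asarray(a, dtype=float)   # per-example gradients for one coordinate j
    N = a.size
    g = a.mean()                     # full-batch gradient g_j
    h = np.mean(a ** 2)              # mean squared per-example gradient h_j (assumed definition)
    w = (N - M) / (M * (N - 1))
    # Uniform p(M): every subset of size M is equally likely.
    batches = itertools.combinations(range(N), M)
    lhs = np.mean([a[list(b)].mean() ** 2 for b in batches])
    rhs = w * h + (1 - w) * g ** 2
    return lhs, rhs
```

The two limits behave as expected: for $M=1$ the weight is $w=1$ and the expectation equals $h_j$, while for $M=N$ it is $w=0$ and the expectation collapses to $g_j^2$.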

Figures (8)

  • Figure 1: Comparison of Adam (left) and one of our proposed methods, Vadam (right). Adam performs maximum-likelihood estimation while Vadam performs variational inference, yet the two pseudocode listings differ only slightly (differences highlighted in red). A major difference is in line 2, where Vadam perturbs the weights before the gradient evaluation.
  • Figure 2: Experiments on Bayesian logistic regression showing (a) posterior approximations on a toy example, (b) performance on 'USPS-3v5' in terms of negative ELBO, log-loss, and the symmetric KL divergence between the posterior approximation and MF-Exact, and (c) the symmetric KL divergence of Vadam for various minibatch sizes on 'Breast-Cancer', compared to VOGN with a minibatch of size 1.
  • Figure 3: The three leftmost panels show results on the Australian-Scale dataset using a neural network with one hidden layer of 64 units, for different minibatch sizes $M$ and numbers of MC samples $S$. VOGN converges the fastest, and Vadam also performs well for $M=1$. The rightmost panel shows results for exploration in deep RL, where Vadam and VadaGrad outperform SGD-based methods.
  • Figure 4: Mean ± one standard error of the test RMSE (using 100 Monte Carlo samples) on the test sets of the UCI experiments, computed over 20 data splits.
  • Figure 5: Early learning performance of Vadam, Adam-Plain, and Adam-Explore on the half-cheetah task in the reinforcement learning experiment. Vadam learns faster in this early stage. The mean and standard error are computed over 5 trials.
  • ...and 3 more figures
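
Two of the metrics in these captions are straightforward to compute once a mean-field Gaussian is available: the symmetric KL divergence used in Figure 2 and the Monte Carlo test RMSE with standard errors over data splits used in Figure 4. The sketch below assumes the symmetric KL is the sum KL(p‖q) + KL(q‖p) (some works halve it), assumes RMSE is taken on the MC-averaged predictive mean, and uses a linear model as a hypothetical stand-in for the paper's networks; all function names are illustrative.

```python
import numpy as np


def sym_kl_diag(m1, v1, m2, v2):
    """KL(p||q) + KL(q||p) for diagonal Gaussians p = N(m1, v1), q = N(m2, v2); v are variances."""
    m1, v1, m2, v2 = map(np.asarray, (m1, v1, m2, v2))
    d2 = (m1 - m2) ** 2
    # The log-determinant terms cancel in the symmetric sum.
    return float(0.5 * np.sum(v1 / v2 + v2 / v1 + d2 * (1 / v1 + 1 / v2) - 2.0))


def mc_rmse(X, y, mu, sigma, n_samples=100, rng=None):
    """RMSE of the MC predictive mean for a hypothetical linear model y ~ X @ theta,
    with theta ~ N(mu, diag(sigma^2))."""
    rng = np.random.default_rng(0) if rng is None else rng
    thetas = mu + sigma * rng.standard_normal((n_samples, mu.size))  # S weight samples
    pred = (X @ thetas.T).mean(axis=1)                                # average over samples
    return float(np.sqrt(np.mean((y - pred) ** 2)))


def mean_and_se(values):
    """Mean and standard error of a metric across data splits."""
    v = np.asarray(values, dtype=float)
    return float(v.mean()), float(v.std(ddof=1) / np.sqrt(v.size))
```

For example, `mean_and_se` applied to the 20 per-split RMSE values would give the point and error bar plotted for one method.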

Theorems & Definitions (2)

  • Theorem 1
  • Theorem 2