Variational Learning Finds Flatter Solutions at the Edge of Stability

Avrajit Ghosh, Bai Cong, Rio Yokota, Saiprasad Ravishankar, Rongrong Wang, Molei Tao, Mohammad Emtiyaz Khan, Thomas Möllenhoff

TL;DR

This paper investigates how Variational Learning (VL) induces implicit regularization via Edge of Stability (EoS) dynamics. By analyzing a quadratic VL problem, it derives a stability threshold of $\tfrac{2}{\rho} \mathrm{VF}(z)$, that is, the GD bound $\tfrac{2}{\rho}$ scaled by the Variational Factor $\mathrm{VF}(z)$, and shows that VL can reach flatter regions than standard GD by tuning the posterior covariance $\boldsymbol{\Sigma}$ and the number of samples $N_s$. The authors extend the analysis to deep networks and validate the theory with extensive experiments on MLPs, ResNets, and Vision Transformers, including heavy-tailed and adaptive posterior methods (VGD, IVON). The results indicate that VL systematically reduces sharpness and improves generalization, and that the stability boundary closely tracks the VF predictions across architectures and tasks. Overall, the work provides a dynamical explanation for VL's empirical success and suggests principled ways to control posterior shape and sampling to achieve flatter, more generalizable solutions.

Abstract

Variational Learning (VL) has recently gained popularity for training deep neural networks. Part of its empirical success can be explained by theories such as PAC-Bayes bounds, minimum description length and marginal likelihood, but little has been done to unravel the implicit regularization in play. Here, we analyze the implicit regularization of VL through the Edge of Stability (EoS) framework. EoS has previously been used to show that gradient descent can find flat solutions and we extend this result to show that VL can find even flatter solutions. This result is obtained by controlling the shape of the variational posterior as well as the number of posterior samples used during training. The derivation follows in a similar fashion as in the standard EoS literature for deep learning, by first deriving a result for a quadratic problem and then extending it to deep neural networks. We empirically validate these findings on a wide variety of large networks, such as ResNet and ViT, to find that the theoretical results closely match the empirical ones. Ours is the first work to analyze the EoS dynamics of VL.

Paper Structure

This paper contains 18 sections, 6 theorems, 69 equations, 29 figures.

Key Result

Lemma 2.1

(Descent Lemma) For a GD update $\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \rho \nabla \ell(\boldsymbol{\theta}_t)$ on the quadratic loss (quad-loss), the loss decreases at each step, that is, we have
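
A minimal numerical sketch of the classical behaviour this lemma refers to, assuming a scalar quadratic $\ell(\theta) = \tfrac{\lambda}{2}\theta^2$. The symbol $\lambda$, the function names, and the specific values are illustrative assumptions, not taken from the paper. On such a quadratic, GD with step size $\rho$ lowers the loss at every step when $\lambda < 2/\rho$ and raises it when $\lambda > 2/\rho$:

```python
def gd_quadratic_losses(lam, rho, theta0=1.0, steps=20):
    """Run GD on l(theta) = 0.5 * lam * theta**2 and return the loss after every step."""
    theta = theta0
    losses = [0.5 * lam * theta ** 2]
    for _ in range(steps):
        # Gradient of the scalar quadratic is lam * theta.
        theta = theta - rho * lam * theta
        losses.append(0.5 * lam * theta ** 2)
    return losses


def is_strictly_decreasing(values):
    """True if every entry is smaller than the one before it."""
    return all(b < a for a, b in zip(values, values[1:]))


rho = 0.1                      # illustrative step size, so 2 / rho = 20
below = gd_quadratic_losses(lam=15.0, rho=rho)   # curvature below 2 / rho
above = gd_quadratic_losses(lam=25.0, rho=rho)   # curvature above 2 / rho

print("lam=15 descends every step:", is_strictly_decreasing(below))
print("lam=25 descends every step:", is_strictly_decreasing(above))
```

Each step multiplies $\theta$ by $(1 - \rho\lambda)$, so the loss contracts exactly when $|1 - \rho\lambda| < 1$, which is the $2/\rho$ bound that the Variational Factor later rescales.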

Figures (29)

  • Figure 1: Panel (a): The left figure shows trajectory traces of VL on a quadratic problem with an isotropic variational posterior whose mean is learned but whose variance is set to a fixed value. The trajectory becomes more unstable as the posterior variance is increased and the number of Monte Carlo samples is decreased. We provide an exact expression to compute the stability threshold at which the iterations become unstable (Theorem \ref{main:proof-thm1}). Panel (b): We show the validity of the threshold on neural network training. The right figure (top) shows this on CIFAR-10 for an MLP, where VL achieves lower sharpness than GD when the posterior variance is increased. The bottom figure shows that the sharpness (solid line) matches the stability threshold obtained by our theorem (dashed line).
  • Figure 2: VL's mechanism for flatter minima: The posterior variance determines the minimum's location. A small variance settles the posterior in a sharp minimum (left), while a larger variance allows it to explore and find a flat minimum (right).
  • Figure 3: (a) The solid black curve shows the theoretical stability threshold of VGD as a function of $N_s / \sigma^2$. The curve lies below the stability threshold of GD, shown as the horizontal dashed gray line. (b) Empirical verification on a scalar quadratic problem with curvature $\lambda$, where we plot the empirically computed probability of descent for VGD runs with different values of $\sigma^2$. We show a heatmap over $(\lambda, 1/\sigma^2)$ values where lighter colors indicate higher probability of descent. We overlay the heatmap with the stability threshold of VGD (red solid curve), showing that the theoretical limit in \ref{eq:vgd_st} matches the empirical probability. (c) The figure further includes $N_s$ and marks the region where a pair $(N_s, \sigma^2)$ will either lead to descent or not (marked 'stable' and 'unstable' respectively).
  • Figure 4: Smaller sharpness corresponds to higher test accuracy for network architectures trained on CIFAR-10. Panels show (a) ViT, (b) MLP, and (c) ResNet-20. Full batch GD with Variational GD achieves lower sharpness and better test accuracy.
  • Figure 5: Normalized Sharpness $\|\nabla^2 \ell(\mathbf{m}_{t}) \|_{2}/(2/\rho)$ hovers around the Variational Factor in MLP.
  • ...and 24 more figures
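
The scalar quadratic experiment described in Figures 1 and 3 can be sketched as follows. This is an illustration under stated assumptions, not the paper's implementation: the loss is $\ell(\theta) = \tfrac{\lambda}{2}\theta^2$, the posterior is $\mathcal{N}(m, \sigma^2)$ with fixed $\sigma^2$, and one step is $m \leftarrow m - \rho \cdot \tfrac{1}{N_s}\sum_i \nabla\ell(m + \sigma\epsilon_i)$. Under these assumptions the loss at the mean decreases in expectation when $\lambda < \tfrac{2}{\rho} \cdot \big(1 + \tfrac{\sigma^2}{N_s m^2}\big)^{-1}$. This expression is derived here for the sketch and is not claimed to be the paper's definition of $\mathrm{VF}(z)$; the function names are hypothetical.

```python
import random


def vgd_threshold(rho, sigma2, n_samples, m):
    """Curvature below which one step lowers the loss at the mean in expectation
    (scalar sketch; assumption of this example, not the paper's formula)."""
    return (2.0 / rho) / (1.0 + sigma2 / (n_samples * m * m))


def simulate_vgd_step(lam, rho, sigma2, n_samples, m=1.0, trials=20000, seed=0):
    """Monte Carlo estimate of the mean loss change and the probability of descent
    for a single step started from mean m."""
    rng = random.Random(seed)
    sigma = sigma2 ** 0.5
    total_change = 0.0
    descents = 0
    for _ in range(trials):
        # Average the gradient lam * theta over n_samples posterior draws.
        grad = sum(lam * (m + sigma * rng.gauss(0.0, 1.0))
                   for _ in range(n_samples)) / n_samples
        m_next = m - rho * grad
        change = 0.5 * lam * (m_next ** 2 - m ** 2)
        total_change += change
        descents += change < 0
    return total_change / trials, descents / trials


rho, sigma2, n_samples = 0.1, 1.0, 1          # illustrative values
thr = vgd_threshold(rho, sigma2, n_samples, m=1.0)
print("GD bound 2/rho:", 2.0 / rho, "| sketch threshold:", thr)
for lam in (0.8 * thr, 1.2 * thr):
    mean_change, p_descent = simulate_vgd_step(lam, rho, sigma2, n_samples)
    print(f"lam={lam:.1f}  mean loss change={mean_change:+.3f}  P(descent)={p_descent:.2f}")
```

In this sketch the threshold moves toward the GD bound $2/\rho$ as $N_s/\sigma^2$ grows and falls below it as the posterior variance increases or the sample count drops, which is the qualitative trend the figure captions describe.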

Theorems & Definitions (10)

  • Lemma 2.1
  • Theorem 3.1
  • Lemma 3.1
  • Definition 3.2: Asymptotic Stability (lyapunov1992general, Chapter 2)
  • Definition 3.3: Stochastic Stability (kushner2006stochastic, Chapter 2)
  • Theorem B.1
  • proof
  • Lemma B.0
  • proof
  • Theorem D.1: Concentration of Smoothed Curvature