Variational Learning Finds Flatter Solutions at the Edge of Stability
Avrajit Ghosh, Bai Cong, Rio Yokota, Saiprasad Ravishankar, Rongrong Wang, Molei Tao, Mohammad Emtiyaz Khan, Thomas Möllenhoff
TL;DR
This paper investigates how Variational Learning (VL) induces implicit regularization through Edge of Stability (EoS) dynamics. By analyzing a quadratic VL problem, it derives a stability threshold of $\tfrac{2}{\rho}\,\mathrm{VF}(z)$, that is, the standard GD bound $\tfrac{2}{\rho}$ scaled by a Variational Factor $\mathrm{VF}(z)$. It then shows that VL can reach flatter regions than standard GD by tuning the posterior covariance $\boldsymbol{\Sigma}$ and the number of posterior samples $N_s$. The authors extend the analysis to deep networks and validate the theory with extensive experiments on MLPs, ResNets, and Vision Transformers, including heavy-tailed and adaptive posterior methods (VGD, IVON). The results indicate that VL systematically reduces sharpness and improves generalization, and that the measured stability boundary closely tracks the VF predictions across architectures and tasks. Overall, the work offers a dynamical explanation for VL's empirical success and suggests principled ways to control the posterior shape and the number of samples to reach flatter, better-generalizing solutions.
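As a concrete picture of where $\boldsymbol{\Sigma}$ and $N_s$ enter a VL update, the following is a minimal sketch of one variational step on a quadratic loss $\tfrac12 w^\top H w$. It assumes a diagonal Gaussian posterior $\mathcal{N}(m, \mathrm{diag}(\sigma^2))$ with $\sigma$ held fixed, updates only the mean with the reparameterization-trick Monte Carlo gradient, and omits any KL or entropy term. The helper name `vl_step` is hypothetical, and this is a generic VL step, not the authors' VGD or IVON implementation.

```python
import numpy as np

def vl_step(m, sigma, H, lr, n_samples, rng):
    """One mean update of a hypothetical variational step on 0.5 * w^T H w.

    Assumption: q(w) = N(m, diag(sigma**2)) with sigma held fixed; the KL/entropy
    term is omitted, so only the expected loss E_q[loss] is minimized.
    """
    # Reparameterization: w_i = m + sigma * eps_i, eps_i ~ N(0, I), i = 1..n_samples
    eps = rng.standard_normal((n_samples, m.shape[0]))
    w = m + sigma * eps
    # Row i of w @ H.T is the loss gradient H w_i evaluated at sample i
    grads = w @ H.T
    # Monte Carlo estimate of grad_m E_q[loss], averaged over the n_samples draws
    g = grads.mean(axis=0)
    return m - lr * g

# Usage: compare one VL step with one plain GD step from the same mean
rng = np.random.default_rng(0)
m = np.array([1.0, -2.0])
H = np.diag([1.0, 3.0])
print(vl_step(m, 0.5, H, 0.1, 8, rng), m - 0.1 * (H @ m))
```

With `sigma = 0` the step reduces to plain GD. On a quadratic the sampled gradient is unbiased for $Hm$, so the mean update differs from GD only through sampling noise whose size shrinks as `n_samples` grows; how this noise and $\boldsymbol{\Sigma}$ shift the stability threshold is what the paper's analysis captures through $\mathrm{VF}(z)$, which this sketch does not compute.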
Abstract
Variational Learning (VL) has recently gained popularity for training deep neural networks. Part of its empirical success can be explained by theories such as PAC-Bayes bounds, minimum description length, and marginal likelihood, but little has been done to unravel the implicit regularization at play. Here, we analyze the implicit regularization of VL through the Edge of Stability (EoS) framework. EoS has previously been used to show that gradient descent can find flat solutions, and we extend this result to show that VL can find even flatter solutions. This result is obtained by controlling the shape of the variational posterior as well as the number of posterior samples used during training. The derivation follows the standard EoS literature for deep learning: we first derive a result for a quadratic problem and then extend it to deep neural networks. We empirically validate these findings on a wide variety of large networks, such as ResNet and ViT, and find that the theoretical results closely match the empirical ones. Ours is the first work to analyze the EoS dynamics of VL.
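The quadratic starting point is easiest to see for plain GD. On $\tfrac12 w^\top H w$ with symmetric positive semi-definite $H$, each eigen-direction is multiplied by $(1-\rho\lambda_i)$ per step, so from a generic start GD converges only if the sharpness $\lambda_{\max}(H)$ is below $2/\rho$; this is the classical bound that the VL threshold rescales. The sketch below assumes $\rho$ denotes the learning rate; `gd_on_quadratic` and `is_stable` are hypothetical helpers, not code from the paper.

```python
import numpy as np

def gd_on_quadratic(H, lr, w0, steps=200):
    # Plain GD on 0.5 * w^T H w: w <- w - lr * H w (hypothetical helper)
    w = np.array(w0, dtype=float)
    for _ in range(steps):
        w = w - lr * (H @ w)
    return w

def is_stable(H, lr):
    # Sharpness = largest Hessian eigenvalue; for PSD H, GD converges
    # (from a generic start) iff the sharpness is below 2 / lr
    sharpness = np.linalg.eigvalsh(H).max()
    return bool(sharpness < 2.0 / lr)

# Usage: one Hessian just below and one just above the 2 / lr edge
lr = 1.0
H_flat = np.diag([1.0, 1.9])   # sharpness 1.9 < 2 / lr
H_sharp = np.diag([1.0, 2.1])  # sharpness 2.1 > 2 / lr
for H in (H_flat, H_sharp):
    print(is_stable(H, lr), np.linalg.norm(gd_on_quadratic(H, lr, [1.0, 1.0])))
```

In the paper's setting, the VL threshold $\tfrac{2}{\rho}\,\mathrm{VF}(z)$ takes the place of the `2.0 / lr` cutoff in `is_stable`; the form of $\mathrm{VF}(z)$ is derived in the paper and is not reproduced here.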
