Variational Inference Failures Under Model Symmetries: Permutation Invariant Posteriors for Bayesian Neural Networks

Yoav Gelberg, Tycho F. A. van der Ouderaa, Mark van der Wilk, Yarin Gal

TL;DR

This paper addresses how weight-space permutation symmetries in Bayesian neural networks induce multimodal posteriors that trip up standard variational inference, which typically uses unimodal approximations. The authors introduce a permutation-invariant variational posterior via G-symmetrization and derive a tractable ELBO that accounts for symmetry through a mutual-information correction estimated with an InfoNCE-based bound. They prove that symmetrized posteriors strictly improve posterior fit and demonstrate, through experiments on a tractable BNN and on MNIST, that VI with symmetrization yields higher ELBOs and better predictive accuracy, especially as model width increases. The approach is architecture-agnostic and provides a practical route to mitigating symmetry-induced biases in VI for Bayesian neural networks.
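The mutual-information correction above is said to be estimated with an InfoNCE-based bound. As a hedged illustration of the generic InfoNCE estimator (not the paper's implementation), the sketch below takes a K x K matrix of critic scores, `scores[i][j] = f(x_i, y_j)` for K paired samples, and returns the standard lower bound on I(X; Y), which can never exceed log K. The critic `f` and the function name `infonce_lower_bound` are hypothetical.

```python
import math

def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def infonce_lower_bound(scores):
    """InfoNCE estimate of I(X; Y) from a K x K critic score matrix.

    scores[i][j] is a (hypothetical) critic value f(x_i, y_j). The diagonal
    holds positive pairs drawn from the joint; off-diagonal entries act as
    negatives. The estimate is bounded above by log K.
    """
    k = len(scores)
    total = 0.0
    for i in range(k):
        # Log-softmax score of the positive pair within row i, shifted by log K.
        total += scores[i][i] - logsumexp(scores[i]) + math.log(k)
    return total / k

# A critic that separates positives sharply approaches log K;
# a constant (uninformative) critic gives exactly zero.
sharp = [[10.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
flat = [[1.0] * 4 for _ in range(4)]
print(infonce_lower_bound(sharp), infonce_lower_bound(flat))
```

Because the bound saturates at log K, a tight estimate of a large mutual information needs many negatives per positive.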

Abstract

Weight-space symmetries in neural network architectures, such as permutation symmetries in MLPs, give rise to Bayesian neural network (BNN) posteriors with many equivalent modes. This multimodality poses a challenge for variational inference (VI) techniques, which typically rely on approximating the posterior with a unimodal distribution. In this work, we investigate the impact of weight-space permutation symmetries on VI. We demonstrate, both theoretically and empirically, that these symmetries lead to biases in the approximate posterior, which degrade predictive performance and posterior fit if not explicitly accounted for. To mitigate this behavior, we leverage the symmetric structure of the posterior and devise a symmetrization mechanism for constructing permutation-invariant variational posteriors. We show that the symmetrized distribution has a strictly better fit to the true posterior, and that it can be trained using the original ELBO objective with a modified KL regularization term. We demonstrate experimentally that our approach mitigates the aforementioned biases and results in improved predictions and a higher ELBO.
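To make the "modified KL regularization term" concrete, here is a short derivation under assumptions we state explicitly (this is our reading, not a quotation of the paper): $G$ is finite and acts on the weights by volume-preserving maps such as permutations; the prior $p(\boldsymbol{\omega})$ and the likelihood $p(\mathcal{D} \mid \boldsymbol{\omega})$ are both $G$-invariant; and $q^G$ is the uniform average of pushforwards described in Figure 1. Writing the joint as $g \sim \mathrm{Uniform}(G)$, $\boldsymbol{\omega} \sim g_{\#} q_\theta$, the entropy of $q^G$ exceeds that of $q_\theta$ by the mutual information $I(g; \boldsymbol{\omega})$, while the likelihood and prior terms are unchanged:

```latex
\begin{align}
q^G(\boldsymbol{\omega}) &= \frac{1}{|G|}\sum_{g \in G} (g_{\#} q_\theta)(\boldsymbol{\omega}), \\
\mathbb{E}_{q^G}\!\left[\log p(\mathcal{D} \mid \boldsymbol{\omega})\right]
  &= \mathbb{E}_{q_\theta}\!\left[\log p(\mathcal{D} \mid \boldsymbol{\omega})\right], \\
\mathrm{KL}\!\left(q^G \,\|\, p(\boldsymbol{\omega})\right)
  &= \mathrm{KL}\!\left(q_\theta \,\|\, p(\boldsymbol{\omega})\right) - I(g; \boldsymbol{\omega}), \\
\mathrm{ELBO}(q^G) &= \mathrm{ELBO}(q_\theta) + I(g; \boldsymbol{\omega}) \;\ge\; \mathrm{ELBO}(q_\theta).
\end{align}
```

Since $0 \le I(g; \boldsymbol{\omega}) \le \log |G|$, and for an MLP the group is a product of one symmetric group $S_{n_l}$ per hidden layer of width $n_l$ (so $|G| = \prod_l n_l!$), an exact sum over $G$ is impractical for wide networks, which is consistent with estimating the correction term with a sample-based bound.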

Paper Structure (14 sections, 7 theorems, 42 equations, 5 figures, 3 tables)

Key Result

Proposition 2.2

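$p(\boldsymbol{\omega} \mid \mathcal{D})$ is $G$-invariant; that is, $p(g \cdot \boldsymbol{\omega} \mid \mathcal{D}) = p(\boldsymbol{\omega} \mid \mathcal{D})$ for every $g \in G$.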
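As a minimal numerical check of Proposition 2.2, the sketch below uses the tractable model from Figure 4, $f(x) = \text{ReLU}(w_1 x) + \text{ReLU}(w_2 x)$, evaluates an unnormalized log posterior, and confirms that it is unchanged when the two hidden units are swapped, $(w_1, w_2) \mapsto (w_2, w_1)$. The standard Gaussian prior, the Gaussian likelihood with noise scale 0.1, the input grid, and $\alpha = 1.5$ are illustrative assumptions, not the paper's settings.

```python
import math

def relu(z):
    return max(z, 0.0)

def model(w1, w2, x):
    # Two-unit ReLU network from Figure 4: ReLU(w1 x) + ReLU(w2 x).
    return relu(w1 * x) + relu(w2 * x)

def log_posterior(w1, w2, xs, ys, noise=0.1):
    # Unnormalized log p(w | D): an assumed N(0, 1) prior on each weight plus
    # an assumed Gaussian likelihood with standard deviation `noise`.
    log_prior = -0.5 * (w1 ** 2 + w2 ** 2)
    log_lik = sum(-0.5 * ((y - model(w1, w2, x)) / noise) ** 2
                  for x, y in zip(xs, ys))
    return log_prior + log_lik

alpha = 1.5
xs = [-2.0 + 0.25 * i for i in range(17)]   # inputs on [-2, 2]
ys = [alpha * abs(x) for x in xs]           # noiseless targets alpha * |x|

# G = {identity, swap} acts by permuting the two hidden units; the
# exact-fit settings (alpha, -alpha) and (-alpha, alpha) score identically.
print(log_posterior(alpha, -alpha, xs, ys), log_posterior(-alpha, alpha, xs, ys))
```

The swap leaves both the prior and every prediction unchanged, so the posterior places equal mass on $(w_1, w_2)$ and $(w_2, w_1)$, which is exactly the source of the duplicated modes shown in Figure 4.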

Figures (5)

  • Figure 1: Symmetrization of the variational posterior. Given a variational distribution $q_\theta(\boldsymbol{\omega})$, its symmetrization with respect to a group $G$ acting on the underlying space is the average of the pushforwards of $q_\theta(\boldsymbol{\omega})$ over all group elements $g \in G$.
  • Figure 2: Unimodal reverse KL minimizer of a bimodal target distribution. Plots of the target distribution $p_\alpha(x)$ and the unimodal reverse KL minimizer $q_{\boldsymbol{\mu}, \boldsymbol{\Sigma}}(x)$ in one dimension.
  • Figure 3: Behavior of the unimodal reverse KL minimizer. We optimize $(\boldsymbol{\mu}^*, \boldsymbol{\Sigma}^*)$ to minimize $\text{KL}(q_{\boldsymbol{\mu}, \boldsymbol{\Sigma}}(\mathbf{x}) \,\|\, p_\alpha(\mathbf{x}))$, where $q_{\boldsymbol{\mu}, \boldsymbol{\Sigma}}(\mathbf{x}) = \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \boldsymbol{\Sigma})$ is a Gaussian distribution and $p_\alpha(\mathbf{x})$ is a mixture of two standard Gaussians with one component centered at $\mathbf{0}$ and another centered $\alpha$ away from $\mathbf{0}$. We plot the results of the optimization as a function of $\alpha$. From left to right, we plot $\alpha^{-1}\lVert \boldsymbol{\mu}^* \rVert_2$ (quantifying the amount of interpolation between $p_\alpha(\mathbf{x})$'s components), $\det(\boldsymbol{\Sigma}^*)$, and the optimal $\hat{\text{KL}}$ reached by optimization.
  • Figure 4: Posterior plots for a tractable BNN. Ground truth and approximate posterior plots for the model $\mathbf{f}^{w_1, w_2}(x) = \text{ReLU}(w_1 x) + \text{ReLU}(w_2 x)$ given data generated by $f_\alpha(x) = \alpha \lvert x \rvert =\text{ReLU}(\alpha x) + \text{ReLU}(-\alpha x)$.
  • Figure 5: Threshold for mode-seeking behavior. We plot $\alpha^{-1}\lVert \boldsymbol{\mu}^* \rVert_2$ (interpreted as quantifying the amount of interpolation between $p_\alpha^\sigma(\mathbf{x})$'s modes) and the optimal $\hat{\text{KL}}$ reached by optimization, for $\sigma = 0.25, 0.5, 1, 2$. The corresponding thresholds, $\alpha^* \approx 1.25, 2.5, 5, 10$, scale linearly with $\sigma$.
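The Figure 3 experiment can be approximated in one dimension with the sketch below. It fits $q = \mathcal{N}(\mu, \sigma^2)$ to $p_\alpha(x) = \tfrac12\mathcal{N}(x; 0, 1) + \tfrac12\mathcal{N}(x; \alpha, 1)$ by minimizing the reverse KL over a $(\mu, \sigma)$ grid, computing $\mathbb{E}_q[\log p_\alpha]$ with the trapezoid rule. The equal mixture weights, the grid ranges, and all function names are assumptions for illustration; the paper's own setup (including higher dimensions and the $\sigma$ sweep of Figure 5) may differ.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)

def log_normal(x, mu, sigma):
    # log N(x; mu, sigma^2)
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * LOG_2PI

def log_target(x, alpha):
    # log p_alpha(x) for the assumed equal-weight mixture, via log-sum-exp.
    a = log_normal(x, 0.0, 1.0)
    b = log_normal(x, alpha, 1.0)
    m = max(a, b)
    return math.log(0.5) + m + math.log(math.exp(a - m) + math.exp(b - m))

def reverse_kl(mu, sigma, alpha, n=201, width=8.0):
    # KL(q || p_alpha) = -H(q) - E_q[log p_alpha]. The Gaussian entropy is closed
    # form; the expectation uses the trapezoid rule on mu +/- width * sigma.
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * sigma ** 2)
    lo = mu - width * sigma
    h = 2.0 * width * sigma / (n - 1)
    expect = 0.0
    for i in range(n):
        x = lo + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0
        expect += weight * math.exp(log_normal(x, mu, sigma)) * log_target(x, alpha)
    return -entropy - expect * h

def fit_gaussian(alpha, n_mu=41):
    # Grid search over mu in [0, alpha] and sigma in [0.5, 2.5].
    # Returns (mu* / alpha, sigma*, minimal reverse KL).
    best = None
    for i in range(n_mu):
        mu = alpha * i / (n_mu - 1)
        for k in range(41):
            sigma = 0.5 + 0.05 * k
            kl = reverse_kl(mu, sigma, alpha)
            if best is None or kl < best[0]:
                best = (kl, mu, sigma)
    kl, mu, sigma = best
    return mu / alpha, sigma, kl

results = {alpha: fit_gaussian(alpha) for alpha in (0.5, 8.0)}
for alpha, (ratio, sigma, kl) in results.items():
    print(f"alpha={alpha}: mu*/alpha={ratio:.3f}, sigma*={sigma:.2f}, KL={kl:.4f}")
```

For small separations (for instance $\alpha \le 2$, where this equal mixture is log-concave and the objective is convex in $(\mu, \sigma)$), the minimizer sits midway between the components, so $\mu^*/\alpha \approx 0.5$. For well-separated components, $q$ collapses onto a single component ($\mu^*/\alpha \approx 0$ or $1$) and the reverse KL approaches $\log 2$.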

Theorems & Definitions (13)

  • Definition 2.1: MLP permutation symmetry group
  • Proposition 2.2
  • Theorem 3.1
  • Theorem 4.1
  • Corollary 4.2
  • Proof of Corollary (posterior fit)
  • Theorem 4.3
  • Proof of Theorem (mode proximity)
  • Proof of Theorem (ELBO correction)
  • Lemma 4.1
  • ...and 3 more
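Definition 2.1 and the symmetrization in Figure 1 can be illustrated together with the sketch below, written for a hypothetical one-hidden-layer MLP $y = W_2\,\text{ReLU}(W_1 x + b_1) + b_2$ in pure Python. Permuting the hidden units (rows of $W_1$, entries of $b_1$, and the matching columns of $W_2$) leaves the network function unchanged, and a sample from the symmetrized posterior $q^G$ can be drawn by sampling weights from $q_\theta$ and then applying a uniformly random permutation. The mean-field Gaussian standing in for $q_\theta$ and every function name here are assumptions; this is not the paper's code.

```python
import random

def relu(z):
    return max(z, 0.0)

def mlp(params, x):
    # y = W2 relu(W1 x + b1) + b2 for a single hidden layer.
    W1, b1, W2, b2 = params
    h = [relu(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
    return [sum(w * hj for w, hj in zip(row, h)) + b for row, b in zip(W2, b2)]

def permute_hidden(params, perm):
    # Permutation acting on hidden units: reorder rows of W1 and entries of b1,
    # and apply the same reordering to the columns of W2.
    W1, b1, W2, b2 = params
    return ([W1[p] for p in perm], [b1[p] for p in perm],
            [[row[p] for p in perm] for row in W2], list(b2))

def sample_q(mean, std, rng):
    # Hypothetical mean-field Gaussian q_theta: independent noise per weight.
    def noisy(m, s):
        if isinstance(m, list):
            return [noisy(mi, si) for mi, si in zip(m, s)]
        return rng.gauss(m, s)
    return tuple(noisy(m, s) for m, s in zip(mean, std))

def sample_symmetrized(mean, std, rng):
    # Sample from q^G: draw omega ~ q_theta, then apply a uniformly random g in G.
    params = sample_q(mean, std, rng)
    perm = list(range(len(params[1])))
    rng.shuffle(perm)
    return permute_hidden(params, perm)

rng = random.Random(0)
n_in, n_hidden, n_out = 2, 4, 1
mean = ([[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_hidden)],
        [rng.gauss(0, 1) for _ in range(n_hidden)],
        [[rng.gauss(0, 1) for _ in range(n_hidden)] for _ in range(n_out)],
        [0.0] * n_out)
std = ([[0.1] * n_in for _ in range(n_hidden)], [0.1] * n_hidden,
       [[0.1] * n_hidden for _ in range(n_out)], [0.1] * n_out)
omega = sample_symmetrized(mean, std, rng)
x = [0.5, -1.0]
print(mlp(omega, x))
```

Because every $g \in G$ preserves the network function, the symmetrized sampler changes where samples land in weight space without changing the distribution over predictions for a fixed draw from $q_\theta$.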