A theoretical perspective on mode collapse in variational inference

Roman Soletskyi, Marylou Gabrié, Bruno Loureiro

TL;DR

This work carries out a theoretical investigation of mode collapse for the gradient flow on Gaussian mixture models, identifies the key low-dimensional statistics characterizing the flow, and derives a closed set of low-dimensional equations governing their evolution.

Abstract

While deep learning has expanded the possibilities for highly expressive variational families, the practical benefits of these tools for variational inference (VI) are often limited by the minimization of the traditional Kullback-Leibler objective, which can yield suboptimal solutions. A major challenge in this context is mode collapse: the phenomenon where a model concentrates on a few modes of the target distribution during training, despite being statistically capable of expressing them all. In this work, we carry out a theoretical investigation of mode collapse for the gradient flow on Gaussian mixture models. We identify the key low-dimensional statistics characterizing the flow, and derive a closed set of low-dimensional equations governing their evolution. Leveraging this compact description, we show that mode collapse is present even in statistically favorable scenarios, and identify two key mechanisms driving it: mean alignment and vanishing weight. Our theoretical findings are consistent with the implementation of VI using normalizing flows, a class of popular generative models, thereby offering practical insights.

Paper Structure

This paper contains 31 sections, 78 equations, 4 figures.

Figures (4)

  • Figure 1: Evolution of gradient descent dynamics from initialization $T=0$ to convergence $T=1$ for VI on a 2-Gaussian mixture target distribution $p(x)$ as in \ref{eq:pstar2d}, with $||\mu_\star||_2 = 2.5$ for the first column, $||\mu_\star||_2 = 3.1$ for the remaining ones, and $w_\star = 1/3$ for all the columns. The mode position is marked by a black cross, and the 9th decile by gray lines. (First and second) $q_{\theta}$, depicted in blue, is parameterized by a normalizing flow; see \ref{sec:app:normalizing}. (Third) $q_{\theta}(x)=w_{1}\mathcal{N}(\mu_{1},I_{2})+w_{2}\mathcal{N}(\mu_{2},I_{2})$, where we optimize over both means and weights $\theta=(w_{1},w_{2},\mu_{1},\mu_{2})$. (Fourth) $q_{\theta}(x)=w_\star\mathcal{N}(\mu_{1},I_{2})+(1-w_\star)\mathcal{N}(\mu_{2},I_{2})$, where we optimize only over the means $\theta=(\mu_{1},\mu_{2})$. In the last two columns, the densities of the different components of $q_\theta$ are depicted in different colors, with opacity proportional to their weight.
  • Figure 2: Basin of attraction of the mean alignment fixed points $m_{1}=m_{2}=\pm 1$ on the $s=0$ cross section of phase space $(m_{1},m_{2},s)\in[-1,1]^{3}$ for $R\in\{1, 1.25, 2, 3\}$, $d = 1000$, $w_{\star} =2/3$, and $\eta=0.05$. The dashed black line denotes the boundary of the basin of attraction of the mode collapse fixed points. Solid lines denote individual flow trajectories with different random initializations $\mu_{1},\mu_{2}\sim{\rm Unif}(\mathbb{S}^{d-1}(R))$. Trajectories that undergo mode collapse are shown in red; those that avoid it are shown in blue.
  • Figure 3: Numerical evolution of sufficient statistics and weights for different radii $R\in\{1,1.9, 2.5\}$. Quasi-mode collapse occurs in the middle case. Learning rate $\eta=0.05$, batch size $B=1000$, $d=10$, and $w_{\star}=2/3$.
  • Figure 4: Dependence of threshold radius $R_c$ on dimension $d$ for different setups
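
The fixed-weight setting in the fourth column of Figure 1 can be illustrated with a small NumPy sketch of stochastic gradient descent on the reverse Kullback-Leibler divergence. This is not the authors' implementation. It assumes a symmetric two-mode target $p(x)=w_\star\mathcal{N}(\mu_\star,I)+(1-w_\star)\mathcal{N}(-\mu_\star,I)$, since the target equation referenced in the caption is not reproduced here. The function names (`mixture_score`, `reverse_kl_mean_gradients`, `fit_means`) and the default step count are hypothetical choices for illustration. The gradient with respect to $\mu_j$ uses the reparameterization $x=\mu_j+\varepsilon$ and the fact that the expected score of $q_\theta$ vanishes, which gives $\nabla_{\mu_j}\mathrm{KL}(q_\theta\|p)=w_j\,\mathbb{E}_{\varepsilon}[\nabla_x\log q_\theta(x)-\nabla_x\log p(x)]$.

```python
import numpy as np


def mixture_score(x, weights, means):
    """Return grad_x log sum_k weights[k] * N(x; means[k], I) for each row of x."""
    diff = means[None, :, :] - x[:, None, :]               # shape (n, K, d)
    logits = np.log(weights)[None, :] - 0.5 * np.sum(diff**2, axis=-1)
    logits = logits - logits.max(axis=1, keepdims=True)    # stabilize the softmax
    resp = np.exp(logits)
    resp = resp / resp.sum(axis=1, keepdims=True)          # responsibilities (n, K)
    return np.einsum("nk,nkd->nd", resp, diff)


def reverse_kl_mean_gradients(q_means, q_weights, p_means, p_weights, rng, batch=1000):
    """Monte Carlo estimate of d KL(q || p) / d q_means with the weights held fixed."""
    grads = np.zeros_like(q_means)
    dim = q_means.shape[1]
    for j in range(len(q_means)):
        # Reparameterized samples from component j of q.
        x = q_means[j] + rng.standard_normal((batch, dim))
        delta = mixture_score(x, q_weights, q_means) - mixture_score(x, p_weights, p_means)
        grads[j] = q_weights[j] * delta.mean(axis=0)
    return grads


def fit_means(mu_star, w_star, init_means, lr=0.05, steps=2000, batch=1000, seed=0):
    """Gradient descent on the two means of q, with weights fixed to (w_star, 1 - w_star).

    Assumption: the target has modes at +mu_star and -mu_star.
    """
    rng = np.random.default_rng(seed)
    p_means = np.stack([mu_star, -mu_star])
    p_weights = np.array([w_star, 1.0 - w_star])
    q_means = np.array(init_means, dtype=float)
    for _ in range(steps):
        q_means -= lr * reverse_kl_mean_gradients(
            q_means, p_weights, p_means, p_weights, rng, batch
        )
    return q_means


if __name__ == "__main__":
    mu_star = np.array([2.5, 0.0])
    print(fit_means(mu_star, 1.0 / 3.0, [[0.5, 1.0], [-0.5, -1.0]]))
```

Running `fit_means` from several initializations, and comparing the overlaps of the fitted means with `mu_star`, is one way to explore how the outcome depends on the starting point. Optimizing the weights as well, as in the third column of Figure 1, would require an additional gradient with respect to the weights that this sketch omits.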