Annealing in variational inference mitigates mode collapse: A theoretical study on Gaussian mixtures

Luigi Fogliani, Bruno Loureiro, Marylou Gabrié

TL;DR

This work provides a mathematical analysis of annealing-based strategies for mitigating mode collapse in a tractable setting: learning a Gaussian mixture, where mode collapse is known to arise.

Abstract

Mode collapse, the failure to capture one or more modes when targeting a multimodal distribution, is a central challenge in modern variational inference. In this work, we provide a mathematical analysis of annealing-based strategies for mitigating mode collapse in a tractable setting: learning a Gaussian mixture, where mode collapse is known to arise. Leveraging a low-dimensional summary-statistics description, we precisely characterize the interplay between the initial temperature and the annealing rate, and derive a sharp formula for the probability of mode collapse. Our analysis shows that an appropriately chosen annealing scheme can robustly prevent mode collapse. Finally, we present numerical evidence that these theoretical tradeoffs qualitatively extend to neural-network-based models, namely RealNVP normalizing flows, providing guidance for designing annealing strategies that mitigate mode collapse in practical variational inference pipelines.
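To make the role of annealing concrete, here is a minimal sketch of a reverse-KL variational objective against a tempered target. It assumes that annealing means raising the target density to a power $\beta$ (an inverse temperature), which is consistent with the quasi-static relation $\sigma^2 = \beta^{-1}$ quoted for Figure 3 but is not spelled out in this summary. The 1D target, the single Gaussian student, and the names `log_target` and `annealed_objective` are hypothetical simplifications of the paper's $d$-dimensional setup.

```python
import numpy as np

# Hypothetical 1D stand-in for the bimodal target: w*N(R, 1) + (1 - w)*N(-R, 1).
# R = 3 and w = 0.8 mirror the hyperparameters quoted in the figure captions.
def log_target(x, R=3.0, w=0.8):
    log_norm = -0.5 * np.log(2.0 * np.pi)
    log_a = np.log(w) + log_norm - 0.5 * (x - R) ** 2
    log_b = np.log(1.0 - w) + log_norm - 0.5 * (x + R) ** 2
    return np.logaddexp(log_a, log_b)  # numerically stable log-sum-exp


def annealed_objective(mu, sigma, beta, n=10_000, seed=0):
    """Monte Carlo estimate of E_q[log q(x) - beta * log pi(x)] for q = N(mu, sigma^2).

    Assumed annealing: the target is tempered as pi^beta. This objective equals
    KL(q || pi^beta / Z_beta) minus log Z_beta, a constant in (mu, sigma).
    """
    rng = np.random.default_rng(seed)
    x = mu + sigma * rng.standard_normal(n)  # reparameterized samples from q
    log_q = -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2
    return float(np.mean(log_q - beta * log_target(x)))


# At small beta the tempered target is nearly flat, so a broad student attains a
# lower objective; at beta = 1, a unit-variance student on the heavier mode
# attains an objective close to -log(0.8).
print(annealed_objective(0.0, 5.0, beta=0.05), annealed_objective(0.0, 0.5, beta=0.05))
print(annealed_objective(3.0, 1.0, beta=1.0))
```

The design choice illustrated here is that a small $\beta$ flattens the target, so a wide student covering both modes is favored early on; increasing $\beta$ toward 1 then recovers the original objective.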

Paper Structure (26 sections, 58 equations, 7 figures)

Figures (7)

  • Figure 1: Preliminary experiments: annealing mitigates mode collapse. The first 4 columns show the marginal density of both $\pi$ and $q_\theta$ along the direction of $\mu_*$ at 4 different training stages. The fifth column shows the variance dynamics. Each row represents a different scenario; from top to bottom: no annealing with $\sigma_{1,2}$ initialized to one; no annealing with a covering initial distribution (high initial variances, $\sigma_{1,2} \gg R$); and annealing with the exponential annealing schedule. Hyperparameter values are $d=512$, $R=3$, $w_*=0.8$, $w_1=0.5$, and learning rate $\eta=0.05$.
  • Figure 2: Annealed VI mode collapse probability with a Gaussian mixture student. The annealing scheme follows the exponential annealing schedule, and different values of the initial inverse temperature $\beta_i$ and the annealing time $t_0$ are explored. Hyperparameters are the same as in Figure 1. For each value of $(\beta_i, t_0)$, the mode collapse probability is estimated from 100 different random initializations of the student means. The dashed black line is the isoprobability line at which the mode collapse probability equals 0.5, obtained from the analytical estimate of the mode collapse probability.
  • Figure 3: Dynamics of the summary statistics. Coloured curves correspond to SGD trajectories of both means and variances in the experiment shown in the third row of Figure 1. Black curves are obtained by numerically integrating the annealed dynamical system for the summary statistics, which assumes perfect gradient estimates and a quasi-static approximation for the variances, $\sigma_{1,2}^2 = \beta(t)^{-1}$. Hyperparameters are $d=512$, $R=3$, and $w_*=0.8$.
  • Figure 4: Annealed VI mode collapse probability with a RealNVP student. Mode collapse probabilities are computed from 200 runs for each pair of initial inverse temperature $\beta_i$ and annealing time $t_0$. We use the exponential annealing schedule on a bimodal Gaussian mixture target with hyperparameters $d=128$, $R=3$, and $w_*=0.8$. The generative model is a RealNVP with 8 coupling layers, each containing two MLPs of depth 4 and hidden dimension $4d=512$.
  • Figure 5: Probability of mode collapse as a function of the annealing rate $\beta_i^{-1/t_0}$ for the exponential annealing schedule. For the Gaussian mixture student, the data are the same as in Figure 2, with hyperparameters $d=512$, $R=3$, $w_*=0.8$. The dashed black line is the high-initial-temperature ($\beta_i \to 0$) asymptote of the analytical mode collapse probability ($\alpha = 0.608$). For the RealNVP student, the data are the same as in Figure 4, with hyperparameters $d=128$, $R=3$, $w_*=0.8$.
  • ...and 2 more figures
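The captions above parameterize the exponential annealing schedule by an initial inverse temperature $\beta_i$ and an annealing time $t_0$, and Figure 5 plots results against the rate $\beta_i^{-1/t_0}$. The sketch below assumes one schedule consistent with that rate: $\beta(t) = \beta_i^{\,1 - t/t_0}$ for $t < t_0$ and $\beta(t) = 1$ afterwards, so that $\beta$ grows geometrically from $\beta_i$ to 1. The exact form used in the paper is not given in this summary, and the function names are hypothetical.

```python
import numpy as np

def exponential_beta(t, beta_i, t0):
    """Assumed exponential schedule: beta(t) = beta_i ** (1 - t / t0) for t < t0, else 1.

    Each step multiplies beta by beta_i ** (-1 / t0), the annealing rate used as the
    x-axis of Figure 5; beta reaches 1 (assumed untempered target) at t = t0.
    """
    t = np.asarray(t, dtype=float)
    return np.where(t < t0, beta_i ** (1.0 - t / t0), 1.0)


def annealing_rate(beta_i, t0):
    # Per-step multiplicative increase of beta under the schedule above.
    return beta_i ** (-1.0 / t0)


def quasi_static_variance(t, beta_i, t0):
    # Quasi-static approximation from Figure 3: the student variances track 1 / beta(t).
    return 1.0 / exponential_beta(t, beta_i, t0)


steps = np.arange(0, 201, 50)
print(exponential_beta(steps, beta_i=1e-2, t0=100))
print(annealing_rate(1e-2, 100))  # equals 100 ** 0.01, about 1.047
```

Under this assumption, a smaller $\beta_i$ (higher initial temperature) or a longer $t_0$ both lower the annealing rate, which is why Figure 5 can collapse runs with different $(\beta_i, t_0)$ onto a single axis.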

Theorems & Definitions (1)

  • Remark 1