Learning Mixture Density via Natural Gradient Expectation Maximization

Yutao Chen, Jasmine Bayrooti, Steven Morad

TL;DR

This work tackles slow convergence and mode collapse in training mixture density networks (MDNs). By reframing MDNs as latent-variable models and uniting expectation maximization with natural gradient geometry, the authors derive natural gradient EM (nGEM), a tractable objective that preconditions updates with a block-diagonal complete-data Fisher information matrix (FIM). Empirically, nGEM yields up to 10x faster convergence, better mode separation, and robust performance on high-dimensional tasks like inverse MNIST, while incurring negligible overhead and offering complementary gains when combined with curvature-aware optimizers. The approach relies on diagonal Gaussian and categorical FIMs to enable efficient component-wise updates, offering a practical enhancement for uncertainty-aware, multimodal regression with MDNs.

Abstract

Mixture density networks are neural networks that produce Gaussian mixtures to represent continuous multimodal conditional densities. Standard training procedures involve maximum likelihood estimation using the negative log-likelihood (NLL) objective, which suffers from slow convergence and mode collapse. In this work, we improve the optimization of mixture density networks by integrating their information geometry. Specifically, we interpret mixture density networks as deep latent-variable models and analyze them through an expectation maximization framework, which reveals surprising theoretical connections to natural gradient descent. We then exploit such connections to derive the natural gradient expectation maximization (nGEM) objective. We show empirically that nGEM achieves up to 10$\times$ faster convergence while adding almost zero computational overhead, and scales well to high-dimensional data where NLL otherwise fails.
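
For orientation, the baseline NLL objective and the latent-variable view mentioned in the abstract can be sketched as follows. This is a minimal NumPy illustration assuming a diagonal-Gaussian mixture for a single target; the function names (`log_gaussian_diag`, `mixture_nll`, `responsibilities`) are hypothetical, and the sketch does not implement the nGEM objective itself, only the standard NLL and the EM posterior over components.

```python
import numpy as np

def logsumexp(a):
    """Numerically stable log(sum(exp(a))) for a 1-D array."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def log_gaussian_diag(y, mu, sigma):
    """Log density of y under K diagonal Gaussians.

    y: (D,) target, mu: (K, D) means, sigma: (K, D) standard deviations.
    Returns an array of shape (K,).
    """
    z = (y - mu) / sigma
    return -0.5 * np.sum(z**2 + 2.0 * np.log(sigma) + np.log(2.0 * np.pi), axis=-1)

def mixture_nll(y, logits, mu, sigma):
    """Negative log-likelihood of y under the mixture (standard MDN loss)."""
    log_pi = logits - logsumexp(logits)  # log mixing weights via log-softmax
    return -logsumexp(log_pi + log_gaussian_diag(y, mu, sigma))

def responsibilities(y, logits, mu, sigma):
    """Posterior p(z = k | y) over components, as in an EM E-step."""
    log_pi = logits - logsumexp(logits)
    log_joint = log_pi + log_gaussian_diag(y, mu, sigma)
    return np.exp(log_joint - logsumexp(log_joint))

# Toy values (assumed for illustration): two unit-variance components at -2 and +2.
y = np.array([0.0])
logits = np.zeros(2)
mu = np.array([[-2.0], [2.0]])
sigma = np.ones((2, 1))

nll = mixture_nll(y, logits, mu, sigma)
resp = responsibilities(y, logits, mu, sigma)
```

In an actual MDN, `logits`, `mu`, and `sigma` would be outputs of a neural network conditioned on the input; here they are fixed arrays so the example stays self-contained.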

Paper Structure

This paper contains 27 sections, 4 theorems, 62 equations, 10 figures, 10 tables, 1 algorithm.

Key Result

Proposition 3.1

Consider a probabilistic model $p(\boldsymbol{x},\boldsymbol{z}|\boldsymbol\theta)$ with observed variables $\boldsymbol{x}$, latent variables $\boldsymbol{z}$, and parameters $\boldsymbol\theta$. We can show that ... where $Q(\boldsymbol\theta | \boldsymbol\theta_{t})$ is the M-step objective (Eq. 10).

Figures (10)

  • Figure 1: (Left) A standard Gaussian distribution compared with (right) a multimodal Gaussian mixture model. Both share the same mean $\mathbb{E}[x]=0$ while the Gaussian mixture is more flexible.
  • Figure 2: Fitting two Gaussians with a Gaussian mixture model (GMM) in $\mathbb{R}^2$. (a) Mode collapse using the NLL loss. (b) Mode separation using the nGEM loss. Heatmaps denote the probability density of the learned GMM, while stars ($\star$) denote means of the ground truth Gaussians. Marginal distributions are also displayed, with hashed gray regions representing the ground truth Gaussians and red regions representing the learned mixture density. (c) Negative log-likelihoods ($\downarrow$) of GMMs with a learning rate $\beta=10^{-2}$, averaged ($\pm$ std) across 5 random seeds.
  • Figure 3: Key concepts and connections in Section 3.2.
  • Figure 4: Negative log-likelihood ($\downarrow$) of the learned GMM on the Two-Gaussians example, with different learning rates $\beta$. Results averaged ($\pm$ std) across 5 random seeds.
  • Figure 5: Trajectories of the GMM ($K=2$) component means on the Two-Gaussians example during training. Stars ($\star$) represent the ground-truth Gaussian means, white squares/circles represent the initial component means, and their red counterparts represent the trajectory of component means logged periodically during training.
  • ...and 5 more figures

Theorems & Definitions (8)

  • Proposition 3.1: salakhutdinov2003optimization, xu2024toward
  • proof
  • Corollary 3.2
  • proof
  • Proposition 3.3: sato2001online
  • proof
  • Proposition 3.4
  • proof