Learning Mixture Density via Natural Gradient Expectation Maximization
Yutao Chen, Jasmine Bayrooti, Steven Morad
TL;DR
This work tackles slow convergence and mode collapse in training mixture density networks (MDNs). By reframing MDNs as latent-variable models and uniting expectation-maximization with natural gradient geometry, the authors derive natural gradient EM (nGEM), a tractable objective that preconditions updates with a block-diagonal complete-data Fisher information matrix (FIM). Empirically, nGEM yields up to 10× faster convergence, better mode separation, and robust performance on high-dimensional tasks such as inverse MNIST, while incurring negligible overhead and providing complementary gains when combined with curvature-aware optimizers. The approach relies on the closed-form FIMs of diagonal Gaussian and categorical distributions to enable efficient, component-wise updates, offering a practical enhancement for uncertainty-aware, multimodal regression with MDNs.
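The link between EM and natural gradients can be illustrated on a plain (unconditional) 1-D Gaussian mixture. The sketch below is a classical correspondence, not the authors' nGEM objective: it takes the ordinary gradient of the observed-data log-likelihood with respect to each component's mean and variance, preconditions it with the responsibility-weighted complete-data Fisher information (block-diagonal across components, diagonal within each block), and checks that a unit-step natural gradient update reproduces the closed-form EM M-step. The function names `responsibilities` and `natural_gradient_step` are hypothetical; mixing weights are held fixed, and the variance update uses the old means, so it matches a simultaneous (not sequential) M-step.

```python
import numpy as np


def responsibilities(x, weights, mu, var):
    """E-step: posterior p(z = k | x_n) for a 1-D Gaussian mixture, shape (N, K)."""
    log_joint = (
        np.log(weights)
        - 0.5 * np.log(2.0 * np.pi * var)
        - 0.5 * (x[:, None] - mu) ** 2 / var
    )
    log_joint -= log_joint.max(axis=1, keepdims=True)  # stabilise before exp
    r = np.exp(log_joint)
    return r / r.sum(axis=1, keepdims=True)


def natural_gradient_step(x, weights, mu, var, lr=1.0):
    """One natural gradient ascent step on (mu_k, var_k) for every component.

    Gradient: ordinary gradient of sum_n log p(x_n).
    Preconditioner: responsibility-weighted complete-data Fisher information,
    using the Gaussian FIM entries 1/var (for mu) and 1/(2 var^2) (for var).
    """
    r = responsibilities(x, weights, mu, var)  # (N, K)
    n_k = r.sum(axis=0)                        # effective count per component
    diff = x[:, None] - mu                     # (N, K)

    # Gradients of the observed-data log-likelihood.
    g_mu = (r * diff).sum(axis=0) / var
    g_var = (r * (diff**2 / (2 * var**2) - 1 / (2 * var))).sum(axis=0)

    # Diagonal complete-data Fisher entries, one block per component.
    f_mu = n_k / var
    f_var = n_k / (2 * var**2)

    return mu + lr * g_mu / f_mu, var + lr * g_var / f_var


# Usage: two well-separated clusters and a rough initialisation.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2.0, 0.5, 200), rng.normal(3.0, 1.0, 300)])
weights = np.array([0.5, 0.5])
mu = np.array([-1.0, 1.0])
var = np.array([1.0, 1.0])

mu_ng, var_ng = natural_gradient_step(x, weights, mu, var)

# Closed-form EM M-step (variance computed around the old means).
r = responsibilities(x, weights, mu, var)
mu_em = (r * x[:, None]).sum(axis=0) / r.sum(axis=0)
var_em = (r * (x[:, None] - mu) ** 2).sum(axis=0) / r.sum(axis=0)

print(np.allclose(mu_ng, mu_em), np.allclose(var_ng, var_em))  # both checks print True
```

Because each component's Fisher block depends only on its own responsibilities, the preconditioning is a per-component division rather than a full matrix inverse, which is the kind of component-wise efficiency the summary attributes to the block-diagonal structure.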
Abstract
Mixture density networks are neural networks that produce Gaussian mixtures to represent continuous, multimodal conditional densities. Standard training procedures perform maximum likelihood estimation using the negative log-likelihood (NLL) objective, which suffers from slow convergence and mode collapse. In this work, we improve the optimization of mixture density networks by integrating their information geometry. Specifically, we interpret mixture density networks as deep latent-variable models and analyze them through an expectation maximization framework, which reveals surprising theoretical connections to natural gradient descent. We then exploit these connections to derive the natural gradient expectation maximization (nGEM) objective. Empirically, nGEM achieves up to 10× faster convergence while adding almost zero computational overhead, and it scales well to high-dimensional data where NLL fails.
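For reference, the baseline NLL objective for an MDN with diagonal Gaussian components can be sketched as follows. This is the standard mixture log-likelihood, not nGEM. The parameterization is an assumption: the network is taken to emit unnormalized mixing logits (passed through a softmax) and per-component means and log standard deviations, and `mdn_nll` is a hypothetical name. The log-sum-exp over components keeps the computation stable when individual component densities underflow.

```python
import numpy as np


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def mdn_nll(logits, mu, log_sigma, y):
    """Mean negative log-likelihood of targets under a diagonal-Gaussian MDN.

    logits:         (B, K) unnormalised mixing weights
    mu, log_sigma:  (B, K, D) per-component means and log standard deviations
    y:              (B, D) regression targets
    """
    log_pi = log_softmax(logits)                     # (B, K)
    z = (y[:, None, :] - mu) / np.exp(log_sigma)     # standardised residuals
    # log N(y | mu_k, diag(sigma_k^2)), summed over the D output dimensions
    log_comp = -0.5 * (z**2 + 2 * log_sigma + np.log(2 * np.pi)).sum(axis=-1)
    # log p(y | x) = logsumexp_k [log pi_k + log N_k]
    a = log_pi + log_comp                            # (B, K)
    a_max = a.max(axis=-1, keepdims=True)
    log_lik = a_max[:, 0] + np.log(np.exp(a - a_max).sum(axis=-1))
    return -log_lik.mean()


# Usage with random "network outputs" for a batch of 4, 3 components, 2-D targets.
rng = np.random.default_rng(0)
B, K, D = 4, 3, 2
logits = rng.normal(size=(B, K))
mu = rng.normal(size=(B, K, D))
log_sigma = rng.normal(scale=0.1, size=(B, K, D))
y = rng.normal(size=(B, D))

loss = mdn_nll(logits, mu, log_sigma, y)
print(loss)
```

Minimizing this quantity by plain gradient descent is the standard procedure the abstract contrasts with nGEM.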
