Fast Compute for ML Optimization
Nick Polson, Vadim Sokolov
TL;DR
The paper introduces Scale Mixture EM (SM-EM), a tuning-free optimizer for losses that admit a variance-mean scale-mixture representation. Using latent-variable data augmentation (e.g., Pólya–Gamma), SM-EM writes each update as a weighted least-squares M-step with the model-derived precision $\tau^{-2}\hat{\Lambda}+X^\top\hat{\Omega}X$. Here the observation weights $\hat{\Omega}$ and parameter weights $\hat{\Lambda}$ are recomputed at every iteration. The approach unifies proximal and adaptive-gradient perspectives: it connects to Adam/AdamW but derives its curvature- and shrinkage-based weights from the loss geometry. It also supports Nesterov acceleration and pathwise amortization over regularization grids. Empirically, SM-EM attains substantially lower final losses than learning-rate-tuned Adam on ill-conditioned logistic regression benchmarks, up to $13\times$ lower with Nesterov acceleration. Regularization paths are accelerated by sharing sufficient statistics across penalty values, which gives a $10\times$ runtime reduction for a 40-point path. Large-scale M-steps use Halton quasi-Monte Carlo. The framework offers a principled, model-based alternative to heuristic adaptive methods, with potential extensions to stochastic/online variants and broader loss classes.
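The following is a minimal sketch of the weighted least-squares M-step described above. It assumes the simplest case the TL;DR mentions: logistic regression with labels $y_i \in \{0,1\}$, Pólya–Gamma augmentation, and a ridge penalty $\beta^\top\beta/(2\tau^2)$. Under that penalty, $\hat{\Lambda}$ is fixed to the identity; the paper's adaptive $\hat{\Lambda}$ for other penalties is not reproduced here. The E-step sets $\hat{\Omega} = \mathrm{diag}(\hat{\omega}_i)$ with $\hat{\omega}_i = \tanh(\psi_i/2)/(2\psi_i)$ and $\psi_i = x_i^\top\beta$. The M-step then solves $(\tau^{-2}\hat{\Lambda}+X^\top\hat{\Omega}X)\beta = X^\top\kappa$ with $\kappa_i = y_i - 1/2$. All function names are hypothetical.

```python
import numpy as np


def pg_weights(psi):
    """E-step weights: E[omega | psi] for omega ~ PG(1, psi).

    Equals tanh(psi / 2) / (2 * psi), with the limit 1/4 at psi = 0.
    """
    psi = np.asarray(psi, dtype=float)
    out = np.full_like(psi, 0.25)
    nz = np.abs(psi) > 1e-8
    out[nz] = np.tanh(psi[nz] / 2.0) / (2.0 * psi[nz])
    return out


def objective(beta, X, y, tau):
    """Ridge-penalized logistic loss; np.logaddexp(0, psi) = log(1 + exp(psi))."""
    psi = X @ beta
    return np.sum(np.logaddexp(0.0, psi) - y * psi) + beta @ beta / (2.0 * tau**2)


def sm_em_logistic(X, y, tau=1.0, n_iter=100):
    """Hypothetical base SM-EM loop: PG E-step, weighted least-squares M-step."""
    p = X.shape[1]
    kappa = y - 0.5          # centered labels for y in {0, 1}
    Lam = np.eye(p)          # assumption: ridge prior, so Lambda-hat = I
    beta = np.zeros(p)
    history = [objective(beta, X, y, tau)]
    for _ in range(n_iter):
        omega = pg_weights(X @ beta)                             # E-step: observation weights
        precision = Lam / tau**2 + X.T @ (omega[:, None] * X)    # tau^-2 Lambda + X' Omega X
        beta = np.linalg.solve(precision, X.T @ kappa)           # M-step: weighted least squares
        history.append(objective(beta, X, y, tau))
    return beta, np.array(history)
```

No learning rate appears anywhere. Each M-step minimizes a quadratic upper bound on the loss, so `history` should be nonincreasing. Because $\tanh(\psi/2)/2 = \sigma(\psi) - 1/2$, a fixed point satisfies $X^\top(\sigma(X\beta) - y) + \beta/\tau^2 = 0$, which is the usual ridge-logistic stationarity condition.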
Abstract
We study optimization for losses that admit a variance-mean scale-mixture representation. Under this representation, each EM iteration is a weighted least-squares update in which latent variables determine observation and parameter weights. These weights play roles analogous to Adam's second-moment scaling and AdamW's weight decay, but they are derived from the model. The resulting Scale Mixture EM (SM-EM) algorithm removes the need for user-specified learning-rate and momentum schedules. On synthetic ill-conditioned logistic regression benchmarks with $p \in \{20, \ldots, 500\}$, SM-EM with Nesterov acceleration attains up to $13\times$ lower final loss than Adam tuned by learning-rate grid search. For a 40-point regularization path, sharing sufficient statistics across penalty values yields a $10\times$ runtime reduction relative to the same tuned-Adam protocol. For the base (non-accelerated) algorithm, EM monotonicity guarantees nonincreasing objective values. Adding Nesterov extrapolation gives up this guarantee in exchange for faster empirical convergence.
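The trade-off in the final sentence can be sketched on ridge-penalized logistic regression by wrapping the Pólya–Gamma EM update in Nesterov extrapolation. The momentum schedule $(k-1)/(k+2)$ and the restart-on-increase rule below are common generic choices assumed for illustration. They are not the paper's exact scheme, and the function names are hypothetical. Unlike the base iteration, this version can record an objective history that increases between steps.

```python
import numpy as np


def pg_weights(psi):
    """E[omega | psi] for omega ~ PG(1, psi): tanh(psi/2) / (2 psi), limit 1/4 at 0."""
    psi = np.asarray(psi, dtype=float)
    out = np.full_like(psi, 0.25)
    nz = np.abs(psi) > 1e-8
    out[nz] = np.tanh(psi[nz] / 2.0) / (2.0 * psi[nz])
    return out


def objective(beta, X, y, tau):
    """Ridge-penalized logistic loss (assumed penalty beta'beta / (2 tau^2))."""
    psi = X @ beta
    return np.sum(np.logaddexp(0.0, psi) - y * psi) + beta @ beta / (2.0 * tau**2)


def em_map(beta, X, kappa, tau):
    """One EM update (E-step + weighted least-squares M-step) from the point beta."""
    omega = pg_weights(X @ beta)
    precision = np.eye(X.shape[1]) / tau**2 + X.T @ (omega[:, None] * X)
    return np.linalg.solve(precision, X.T @ kappa)


def sm_em_nesterov(X, y, tau=1.0, n_iter=100):
    """Hypothetical Nesterov-accelerated EM with an assumed momentum-restart rule."""
    kappa = y - 0.5
    beta_prev = beta = np.zeros(X.shape[1])
    f_prev = objective(beta, X, y, tau)
    history = [f_prev]
    k = 1  # momentum counter; reset to 1 after the objective increases
    for _ in range(n_iter):
        z = beta + (k - 1) / (k + 2) * (beta - beta_prev)   # Nesterov extrapolation
        beta_prev, beta = beta, em_map(z, X, kappa, tau)     # EM step taken from z
        f = objective(beta, X, y, tau)
        history.append(f)
        k = 1 if f > f_prev else k + 1                        # restart momentum on an increase
        f_prev = f
    return beta, np.array(history)
```

The restart only resets momentum after an increase has happened. It does not prevent the increase, so the monotonicity guarantee is still given up, consistent with the trade-off described above.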
