Scalable Mean-Field Variational Inference via Preconditioned Primal-Dual Optimization
Jinhua Lyu, Tianmin Yu, Ying Ma, Naichen Shi
TL;DR
This paper proposes PD-VI, a mini-batch primal-dual framework for scalable mean-field variational inference, and its block-preconditioned extension P$^2$D-VI to address heterogeneity across parameter blocks. By reformulating MFVI as a consensus-constrained finite-sum problem and applying an augmented Lagrangian with careful block-wise preconditioning, the method achieves stable, constant-step updates and provable convergence: $O(1/T)$ in nonconvex settings and linear convergence under strong convexity. The approach jointly updates global and local variational parameters using per-sample dual variables, enabling efficient mini-batch optimization without diminishing step sizes, even for non-conjugate models. Empirical results on synthetic Gaussian mixtures and a large-scale spatial transcriptomics dataset show faster convergence and higher-quality solutions than standard stochastic variational inference baselines.
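For orientation, a generic consensus-constrained finite-sum problem and its augmented Lagrangian can be written as below. This is a sketch of the standard textbook form, not the paper's exact objective: the symbols $f_i$ (per-sample loss, e.g. a negative ELBO term), $\theta$ (global variational parameters), $\theta_i$ (local copy of $\theta$ for sample $i$), $\phi_i$ (local variational parameters), $y_i$ (per-sample dual variable), and $\rho$ (penalty parameter) are all assumed here for illustration.

$$
\min_{\theta,\,\{\theta_i,\phi_i\}_{i=1}^n} \; \frac{1}{n}\sum_{i=1}^n f_i(\theta_i,\phi_i)
\quad \text{s.t.} \quad \theta_i = \theta, \quad i = 1,\dots,n,
$$

$$
\mathcal{L}_\rho\big(\theta,\{\theta_i,\phi_i,y_i\}\big)
= \frac{1}{n}\sum_{i=1}^n \Big[ f_i(\theta_i,\phi_i) + \langle y_i,\ \theta_i - \theta \rangle + \frac{\rho}{2}\,\lVert \theta_i - \theta \rVert^2 \Big].
$$

In this form, a mini-batch method only needs to touch the terms $i$ in the sampled batch at each iteration, which is what makes per-sample dual variables natural.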
Abstract
In this work, we investigate the large-scale mean-field variational inference (MFVI) problem from a mini-batch primal-dual perspective. By reformulating MFVI as a constrained finite-sum problem, we develop a novel primal-dual algorithm based on an augmented Lagrangian formulation, termed primal-dual variational inference (PD-VI). PD-VI jointly updates global and local variational parameters in the evidence lower bound in a scalable manner. To further account for heterogeneous loss geometry across different variational parameter blocks, we introduce a block-preconditioned extension, P$^2$D-VI, which adapts the primal-dual updates to the geometry of each parameter block and improves both numerical robustness and practical efficiency. We establish convergence guarantees for both PD-VI and P$^2$D-VI under a properly chosen constant step size, without relying on conjugacy assumptions or explicit bounded-variance conditions. In particular, we prove $O(1/T)$ convergence to a stationary point in general settings and linear convergence under strong convexity. Numerical experiments on synthetic data and a real large-scale spatial transcriptomics dataset demonstrate that our methods consistently outperform existing stochastic variational inference approaches in terms of convergence speed and solution quality.
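To make the mini-batch primal-dual idea concrete, here is a minimal sketch on a toy problem. It is not the PD-VI algorithm itself: the function name `minibatch_primal_dual`, the scalar quadratic losses $f_i(x) = \tfrac{1}{2}(x - a_i)^2$, the penalty `rho`, and the update order are all assumptions chosen for illustration. Each iteration samples a batch, solves the local augmented-Lagrangian subproblem in closed form, takes a dual ascent step for the sampled indices only, and then refreshes the shared variable, all with a constant penalty parameter.

```python
import numpy as np


def minibatch_primal_dual(a, rho=2.0, batch_size=5, n_iters=2000, seed=0):
    """Toy consensus problem: min_z (1/n) * sum_i 0.5 * (z - a[i])**2.

    Reformulated with local copies x[i] constrained to equal z, and one
    dual variable y[i] per sample (hypothetical illustration only).
    """
    rng = np.random.default_rng(seed)
    n = a.shape[0]
    x = np.zeros(n)  # local copies of the shared variable
    y = np.zeros(n)  # per-sample dual variables
    z = 0.0          # shared (global) variable

    for _ in range(n_iters):
        batch = rng.choice(n, size=batch_size, replace=False)

        # Local primal step: closed-form minimiser over x of
        # 0.5*(x - a_i)^2 + y_i*(x - z) + (rho/2)*(x - z)^2.
        x[batch] = (a[batch] - y[batch] + rho * z) / (1.0 + rho)

        # Dual ascent step on the sampled constraints x_i = z.
        y[batch] += rho * (x[batch] - z)

        # Global step: exact minimiser of the augmented Lagrangian over z,
        # using stored (possibly stale) x and y for unsampled indices.
        z = np.mean(x + y / rho)

    return z, x, y
```

On this toy problem the shared variable should approach the sample mean of `a`, the local copies should agree with it, and the dual variables should average to zero, without any decay of `rho` over iterations.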
