Scalable Mean-Field Variational Inference via Preconditioned Primal-Dual Optimization

Jinhua Lyu, Tianmin Yu, Ying Ma, Naichen Shi

TL;DR

This paper proposes PD-VI, a mini-batch primal–dual framework for scalable mean-field variational inference, and its block-preconditioned extension P$^2$D-VI, which addresses heterogeneity across parameter blocks. By reformulating MFVI as a consensus-constrained finite-sum problem and applying an augmented Lagrangian with careful block-wise preconditioning, the method achieves stable updates with constant step sizes and provable convergence: an $O(1/T)$ rate to a stationary point in nonconvex settings and linear convergence under strong convexity. The approach jointly updates global and local variational parameters using per-sample dual variables, enabling efficient mini-batch optimization without diminishing step sizes, even for non-conjugate models. Empirical results on synthetic Gaussian mixtures and a large-scale spatial transcriptomics dataset show faster convergence and higher-quality solutions than standard stochastic variational inference baselines, demonstrating practical impact for large-scale Bayesian inference.
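The block below writes the reformulation in symbols: one standard consensus form of a finite-sum negative ELBO and its augmented Lagrangian. It is a sketch under assumptions. The per-sample term $f_i$ (a per-sample negative-ELBO term carrying a $1/n$ share of the global prior and entropy terms), the local copies $\lambda_i$, the multipliers $y_i$, and the penalty $\rho$ are our own notation. The paper's exact splitting, preconditioning, and step sizes may differ.

```latex
% Assumed notation: \lambda = global variational parameters, \phi_i = local parameters of sample i,
% f_i = per-sample negative-ELBO term (hypothetical split of the global terms across samples).
\begin{aligned}
&\min_{\lambda,\,\{\phi_i\}} \; \sum_{i=1}^{n} f_i(\lambda, \phi_i)
\;\Longleftrightarrow\;
\min_{\lambda,\,\{\lambda_i,\,\phi_i\}} \; \sum_{i=1}^{n} f_i(\lambda_i, \phi_i)
\quad \text{s.t. } \lambda_i = \lambda,\; i = 1,\dots,n, \\[4pt]
&\mathcal{L}_\rho\big(\lambda, \{\lambda_i, \phi_i, y_i\}\big)
= \sum_{i=1}^{n} \Big[ f_i(\lambda_i, \phi_i)
+ \langle y_i,\, \lambda_i - \lambda \rangle
+ \tfrac{\rho}{2}\, \lVert \lambda_i - \lambda \rVert^2 \Big], \\[4pt]
&y_i \leftarrow y_i + \rho\,(\lambda_i - \lambda) \quad \text{for } i \in S_t
\quad \text{(standard multiplier step, applied only to sampled indices).}
\end{aligned}
```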

Abstract

In this work, we investigate the large-scale mean-field variational inference (MFVI) problem from a mini-batch primal-dual perspective. By reformulating MFVI as a constrained finite-sum problem, we develop a novel primal-dual algorithm based on an augmented Lagrangian formulation, termed primal-dual variational inference (PD-VI). PD-VI jointly updates global and local variational parameters in the evidence lower bound in a scalable manner. To further account for heterogeneous loss geometry across different variational parameter blocks, we introduce a block-preconditioned extension, P$^2$D-VI, which adapts the primal-dual updates to the geometry of each parameter block and improves both numerical robustness and practical efficiency. We establish convergence guarantees for both PD-VI and P$^2$D-VI under properly chosen constant step sizes, without relying on conjugacy assumptions or explicit bounded-variance conditions. In particular, we prove $O(1/T)$ convergence to a stationary point in general settings and linear convergence under strong convexity. Numerical experiments on synthetic data and a real large-scale spatial transcriptomics dataset demonstrate that our methods consistently outperform existing stochastic variational inference approaches in terms of convergence speed and solution quality.
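This summary does not reproduce the exact PD-VI updates. The Python sketch below is a hedged illustration of the ingredients named above: a consensus reformulation, one auxiliary variable per sample, a constant step size, and uniform mini-batches of fixed size $m$. It swaps in a different, well-known method, randomized block-coordinate Douglas-Rachford splitting, which is closely related to consensus ADMM. The toy objective, a sum of scalar quadratics $f_i(z) = \tfrac{1}{2} a_i (z - b_i)^2$ standing in for per-sample ELBO terms, is hypothetical. So are the function names `prox_quadratic` and `minibatch_consensus_dr`.

```python
import random


def prox_quadratic(v, a, b, gamma):
    """Proximal map of gamma * 0.5 * a * (x - b)**2, evaluated at v."""
    return (v + gamma * a * b) / (1.0 + gamma * a)


def minibatch_consensus_dr(a, b, m, gamma=0.5, iters=5000, seed=0):
    """Randomized block-coordinate Douglas-Rachford for
    min_z sum_i 0.5 * a[i] * (z - b[i])**2 in consensus form
    (local copies x_i constrained to equal a global z).

    Hypothetical toy stand-in for per-sample ELBO terms; this is NOT PD-VI.
    """
    n = len(a)
    rng = random.Random(seed)
    w = [0.0] * n  # one auxiliary variable per sample
    z = 0.0        # global variable; always equals mean(w)
    for _ in range(iters):
        # Uniform mini-batch of fixed size m, sampled without replacement.
        for i in rng.sample(range(n), m):
            # Local step: prox of f_i at the reflected point 2z - w_i.
            x_i = prox_quadratic(2.0 * z - w[i], a[i], b[i], gamma)
            # Only sampled blocks move; z is held fixed within the batch.
            w[i] += x_i - z
        # Global (consensus) step: project onto the agreement set.
        z = sum(w) / n
    # y_i = (w_i - z) / gamma acts as a per-sample multiplier for x_i = z;
    # the y_i sum to zero by construction.
    duals = [(wi - z) / gamma for wi in w]
    return z, duals
```

On a well-conditioned toy such as `a = [1.0 + 0.1 * i for i in range(20)]` with `m=5`, the returned `z` approaches the weighted mean $z^\star = \sum_i a_i b_i / \sum_i a_i$, and each `duals[i]` approaches $-a_i(z^\star - b_i)$, that is, minus the per-sample gradient at the solution. The step parameter `gamma` stays constant throughout. Randomized block-coordinate versions of such splitting iterations are known to converge for convex problems, but this sketch makes no claim about the paper's rates.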

Paper Structure (32 sections, 29 theorems, 209 equations, 5 figures, 1 table, 3 algorithms)

Key Result

Theorem 4.6

Assume that at each iteration $t$, the mini-batch $S_t$ is sampled uniformly at random with fixed size $|S_t|=m$, and that $\boldsymbol{\phi}$ and $\boldsymbol{\lambda}$ denote the local and global variables, respectively. Then Algorithms alg:meta and alg:oracle_1 satisfy: (i) (Strongly convex). Under Assumption
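The theorem's sampling assumption has a standard consequence, shown in the sketch below. With $S_t$ drawn uniformly among all subsets of size $m$, every index is included with probability $m/n$, so the rescaled mini-batch sum $\tfrac{n}{m}\sum_{i \in S_t} g_i$ is an unbiased estimate of $\sum_{i=1}^n g_i$. The sketch checks this by exact enumeration, with integer `values` standing in for per-sample gradients. The function names are hypothetical. How the truncated theorem uses this fact is not shown here.

```python
from fractions import Fraction
from itertools import combinations


def minibatch_sum_estimate(values, batch):
    """Rescaled mini-batch sum (n / m) * sum_{i in batch} values[i]."""
    return Fraction(len(values), len(batch)) * sum(values[i] for i in batch)


def expected_estimate(values, m):
    """Exact expectation over all size-m subsets, each equally likely."""
    batches = list(combinations(range(len(values)), m))
    return sum(minibatch_sum_estimate(values, b) for b in batches) / len(batches)


def inclusion_probability(n, m, i=0):
    """Fraction of size-m subsets of range(n) that contain index i."""
    batches = list(combinations(range(n), m))
    return Fraction(sum(i in b for b in batches), len(batches))
```

For example, with `vals = [3, -1, 4, 1, 5, -9]`, `expected_estimate(vals, m)` equals `sum(vals)` for every `m` from 1 to 6. Exact `Fraction` arithmetic avoids floating-point noise in the comparison.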

Figures (5)

  • Figure 1: Ground-truth probability density of a Gaussian mixture model, together with the densities reconstructed by SVI with diminishing step sizes and by P$^2$D-VI (our algorithm).
  • Figure 2: Loss curves on a strongly convex quadratic problem.
  • Figure 3: Wasserstein distance between the true and variational Gaussian mixture distributions on the synthetic dataset.
  • Figure 4: Convergence behavior and clustering performance on the MOSTA dataset. From left to right, we show the evolution of the variational objective value (excluding constant terms) as a function of iteration count and wall-clock time, the adjusted Rand index (ARI) versus iteration, and the norm of the gradient with respect to the global variables versus iteration (displayed on a logarithmic scale). Solid lines represent the mean across runs, and shaded regions indicate one standard deviation over five random seeds. For clarity of visualization, the iteration range and wall-clock time shown correspond to a truncated segment of the full training process; extending the runs beyond the displayed range does not change the observed behavior.
  • Figure 5: Spatial domains identified by P$^2$D-VI and competing methods across sagittal sections of mouse embryos. Ground-truth anatomical and tissue annotations are obtained from the MOSTA reference atlas generated in the original study (left). Clustering performance of different methods is quantified using the adjusted Rand index (ARI), where higher ARI values indicate better agreement with the reference annotations.
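Figures 4 and 5 report clustering quality with the adjusted Rand index (ARI). The sketch below shows how that number is computed. It implements the standard pair-counting ARI of Hubert and Arabie from a contingency table, using only the Python standard library. The function name is hypothetical; in practice, `sklearn.metrics.adjusted_rand_score` computes the same quantity. ARI equals 1 for identical partitions up to relabeling, is near 0 for random labelings, and can be negative.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Pair-counting adjusted Rand index between two labelings."""
    if len(labels_true) != len(labels_pred):
        raise ValueError("label sequences must have equal length")
    n = len(labels_true)
    if n < 2:
        return 1.0  # degenerate input; treat as perfect agreement (a convention)
    contingency = Counter(zip(labels_true, labels_pred))  # n_ij counts
    rows = Counter(labels_true)   # cluster sizes in the reference labeling
    cols = Counter(labels_pred)   # cluster sizes in the predicted labeling
    index = sum(comb(c, 2) for c in contingency.values())
    sum_rows = sum(comb(c, 2) for c in rows.values())
    sum_cols = sum(comb(c, 2) for c in cols.values())
    expected = sum_rows * sum_cols / comb(n, 2)
    max_index = (sum_rows + sum_cols) / 2
    if max_index == expected:
        # e.g. both labelings put every point in one cluster
        return 1.0
    return (index - expected) / (max_index - expected)
```

For instance, `adjusted_rand_index([0, 0, 1, 2], [0, 0, 1, 1])` evaluates to 4/7, and swapping cluster names in either argument leaves the score unchanged.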

Theorems & Definitions (57)

  • Remark 3.1
  • Theorem 4.6
  • Theorem 4.7
  • Theorem 2.3
  • Proof
  • Lemma 2.4
  • Proof
  • Lemma 2.5
  • Proof
  • Lemma 2.6
  • ...and 47 more