Faster Sampling via Stochastic Gradient Proximal Sampler
Xunpeng Huang, Difan Zou, Yi-An Ma, Hanze Dong, Tong Zhang
TL;DR
This work addresses scalable unbiased sampling from non-log-concave targets by introducing a stochastic proximal sampler (SPS) that extends the proximal sampler to finite-sum objectives. The framework replaces the full energy $f$ with minibatch energies $f_{\mathbf{b}}$ and alternates between sampling $x$ and $y$ under the resulting randomized joint target. The inner sampling step is implemented by SGLD or MALA, giving SPS-SGLD and SPS-MALA. Assuming the second moment of the trajectory stays bounded and log-Sobolev inequalities (LSIs) hold, the authors establish nonasymptotic total variation (TV) guarantees with gradient complexities of $\tilde{\mathcal{O}}(d\epsilon^{-2})$ for SPS-SGLD and $\tilde{\mathcal{O}}(d^{1/2}\epsilon^{-2})$ for SPS-MALA. Both improve on prior stochastic-gradient methods. Experiments on synthetic data corroborate these gains, showing lower TV error at competitive compute time. This supports the practicality of stochastic proximal sampling for large-scale, potentially non-convex targets.
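One way to write out the alternating scheme is shown below. It is a sketch that assumes the standard proximal-sampler construction. The step size $\eta$, the minibatch $\mathbf{b}_k$ drawn at iteration $k$, and the averaging normalization of $f$ and $f_{\mathbf{b}}$ are our assumptions, not notation confirmed by the summary.

```latex
\begin{aligned}
f(x) &= \frac{1}{n}\sum_{i=1}^{n} f_i(x), \qquad f_{\mathbf{b}}(x) = \frac{1}{|\mathbf{b}|}\sum_{i\in\mathbf{b}} f_i(x),\\
\pi_{\mathbf{b}}(x, y) &\propto \exp\Big(-f_{\mathbf{b}}(x) - \tfrac{1}{2\eta}\|x - y\|^2\Big),\\
y_k \mid x_k &\sim \mathcal{N}(x_k, \eta I_d),\\
x_{k+1} \mid y_k &\sim \pi_{\mathbf{b}_k}^{X\mid Y}(\cdot \mid y_k) \propto \exp\Big(-f_{\mathbf{b}_k}(x) - \tfrac{1}{2\eta}\|x - y_k\|^2\Big).
\end{aligned}
```

Because $f_{\mathbf{b}}$ depends only on $x$, the $y$-conditional is an exact Gaussian. The $x$-marginal of $\pi_{\mathbf{b}}$ is proportional to $\exp(-f_{\mathbf{b}})$. The last line is the restricted Gaussian oracle that the inner SGLD or MALA sampler must approximate.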
Abstract
Stochastic gradients have been widely integrated into Langevin-based methods to improve their scalability and efficiency on large-scale sampling problems. However, stochastic variants of the proximal sampler have yet to be explored, even though in the deterministic setting it converges much faster than Langevin-based algorithms (Lee et al., 2021). In this paper, we study Stochastic Proximal Samplers (SPS) for sampling from non-log-concave distributions. We first establish a general framework for implementing stochastic proximal samplers and develop the corresponding convergence theory. We show that convergence to the target distribution is guaranteed as long as two conditions hold: the second moment of the algorithm trajectory is bounded, and the restricted Gaussian oracles can be well approximated. We then provide two implementable variants based on stochastic gradient Langevin dynamics (SGLD) and the Metropolis-adjusted Langevin algorithm (MALA), giving rise to SPS-SGLD and SPS-MALA. We further show that SPS-SGLD and SPS-MALA achieve $\epsilon$ sampling error in total variation (TV) distance with gradient complexities of $\tilde{\mathcal{O}}(d\epsilon^{-2})$ and $\tilde{\mathcal{O}}(d^{1/2}\epsilon^{-2})$, respectively. These bounds improve on the best-known result by at least a factor of $\tilde{\mathcal{O}}(d^{1/3})$. Our empirical studies on synthetic data of varying dimensions corroborate this improvement and demonstrate the efficiency of the proposed algorithms.
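The SPS loop with either inner sampler can be sketched in NumPy. This is a minimal illustration under several assumptions.

- The energy is a toy quadratic finite sum, $f_i(x)=\tfrac12\|x-a_i\|^2$. It is log-concave and was chosen only so that the target $\mathcal{N}(\bar a, I)$ is known. The paper studies non-log-concave targets.
- One minibatch $\mathbf{b}$ is drawn per outer iteration and held fixed for the whole inner loop.
- The names and defaults (`sps`, `eta`, `tau`, `steps`, `batch`) are hypothetical. They are not the authors' code or their tuned step sizes.

```python
import numpy as np

# Toy finite-sum energy (assumption, chosen so the answer is known):
# f(x) = (1/n) * sum_i ||x - a_i||^2 / 2, hence exp(-f) is N(mean(a), I).
def f_batch(x, a_b):
    return 0.5 * np.mean(np.sum((x - a_b) ** 2, axis=1))

def grad_f_batch(x, a_b):
    return x - a_b.mean(axis=0)

def rgo_energy(x, y, a_b, eta):
    # U(x) = f_b(x) + ||x - y||^2 / (2 eta): the minibatch restricted Gaussian oracle target
    return f_batch(x, a_b) + np.sum((x - y) ** 2) / (2 * eta)

def rgo_grad(x, y, a_b, eta):
    return grad_f_batch(x, a_b) + (x - y) / eta

def inner_sgld(x, y, a_b, eta, tau, steps, rng):
    # Langevin steps on U; with b fixed for the inner loop this is plain ULA on U
    for _ in range(steps):
        noise = np.sqrt(2 * tau) * rng.standard_normal(x.shape)
        x = x - tau * rgo_grad(x, y, a_b, eta) + noise
    return x

def inner_mala(x, y, a_b, eta, tau, steps, rng):
    # Metropolis-adjusted Langevin steps on U (leaves the fixed-b RGO target invariant)
    def log_q(to, frm):
        # log density (up to a constant) of proposing `to` from `frm`
        mean = frm - tau * rgo_grad(frm, y, a_b, eta)
        return -np.sum((to - mean) ** 2) / (4 * tau)

    for _ in range(steps):
        noise = np.sqrt(2 * tau) * rng.standard_normal(x.shape)
        prop = x - tau * rgo_grad(x, y, a_b, eta) + noise
        log_alpha = (rgo_energy(x, y, a_b, eta) - rgo_energy(prop, y, a_b, eta)
                     + log_q(x, prop) - log_q(prop, x))
        if np.log(rng.uniform()) < log_alpha:
            x = prop
    return x

def sps(data, x0, eta=1.0, batch=20, outer=2000, inner="sgld",
        tau=0.05, steps=20, seed=0):
    # Hypothetical defaults; not tuned values from the paper.
    rng = np.random.default_rng(seed)
    step = {"sgld": inner_sgld, "mala": inner_mala}[inner]
    x = np.asarray(x0, dtype=float)
    samples = []
    for _ in range(outer):
        # y-step: exact Gaussian conditional y ~ N(x, eta I)
        y = x + np.sqrt(eta) * rng.standard_normal(x.shape)
        # fresh minibatch b defines this iteration's randomized target
        a_b = data[rng.choice(len(data), size=batch, replace=False)]
        # x-step: approximate the RGO exp(-f_b(x) - ||x - y||^2 / (2 eta))
        x = step(x, y, a_b, eta, tau, steps, rng)
        samples.append(x.copy())
    return np.array(samples)
```

For example, `sps(data, x0=np.zeros(2), inner="mala")` returns the chain as an `(outer, d)` array. Switching from `"sgld"` to `"mala"` adds an accept/reject step, so the inner chain targets the fixed-$\mathbf{b}$ oracle exactly rather than only up to discretization error. The cost is extra energy evaluations per inner step.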
