
Faster Sampling via Stochastic Gradient Proximal Sampler

Xunpeng Huang, Difan Zou, Yi-An Ma, Hanze Dong, Tong Zhang

TL;DR

This work addresses scalable unbiased sampling from non-log-concave targets by introducing a stochastic proximal sampler (SPS) that extends the proximal sampler to finite-sum objectives. The framework replaces the full energy $f$ with minibatch energies $f_{\mathbf{b}}$ and alternates between sampling $x$ and $y$ under a randomized joint target, using SGLD or MALA as the inner sampler (SPS-SGLD and SPS-MALA). The authors establish nonasymptotic TV guarantees under log-Sobolev inequalities (LSIs) and bounded second moments of the iterates, achieving gradient complexities of $\tilde{O}(d\epsilon^{-2})$ for SPS-SGLD and $\tilde{O}(d^{1/2}\epsilon^{-2})$ for SPS-MALA, which improve on prior stochastic-gradient methods. Experiments on synthetic data corroborate the theoretical gains, showing reduced TV error at competitive compute time and indicating that stochastic proximal sampling is practical for large-scale targets with potentially non-convex energies.
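The alternating scheme described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes $f = \frac{1}{n}\sum_i f_i$, uses a hypothetical helper `grad_fi(x, idx)` that returns the average gradient of the $f_i$ over the minibatch `idx`, draws $y \mid x \sim \mathcal{N}(x, \eta I)$, and approximates the restricted Gaussian oracle $\propto \exp(-f(z) - \|z-y\|^2/(2\eta))$ with a fixed number of SGLD steps that resample a minibatch at every inner step. The paper's minibatch schedule, step sizes, and inner iteration counts may differ.

```python
import numpy as np


def sps_sgld(grad_fi, n, x0, eta, K, inner_steps, inner_step, batch_size, rng=None):
    """Hypothetical SPS-SGLD sketch targeting p_*(x) proportional to exp(-f(x)),
    with f = (1/n) * sum_i f_i.

    grad_fi(x, idx) is an assumed helper returning the average gradient of f_i over idx.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(K):
        # Forward step: draw y | x ~ N(x, eta * I)
        y = x + np.sqrt(eta) * rng.standard_normal(x.shape)
        # Backward step: approximate the restricted Gaussian oracle
        # exp(-f(z) - ||z - y||^2 / (2 eta)) with inner SGLD, warm-started at x
        z = x.copy()
        for _ in range(inner_steps):
            idx = rng.choice(n, size=batch_size, replace=False)
            grad = grad_fi(z, idx) + (z - y) / eta  # stochastic gradient of the inner energy
            z = z - inner_step * grad + np.sqrt(2 * inner_step) * rng.standard_normal(z.shape)
        x = z
    return x
```

Passing `x0` with shape `(chains, d)` runs several chains at once; in this sketch all chains share the same minibatch at each inner step, which keeps the code short but correlates the chains.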

Abstract

Stochastic gradients have been widely integrated into Langevin-based methods to improve their scalability and efficiency on large-scale sampling problems. However, the proximal sampler, which converges much faster than Langevin-based algorithms in the deterministic setting (Lee et al., 2021), has not yet been studied in stochastic variants. In this paper, we study Stochastic Proximal Samplers (SPS) for sampling from non-log-concave distributions. We first develop a general framework for implementing stochastic proximal samplers and establish the corresponding convergence theory. We show that convergence to the target distribution is guaranteed as long as the second moment of the algorithm's trajectory is bounded and the restricted Gaussian oracles can be well approximated. We then provide two implementable variants based on stochastic gradient Langevin dynamics (SGLD) and the Metropolis-adjusted Langevin algorithm (MALA), giving rise to SPS-SGLD and SPS-MALA. We further show that SPS-SGLD and SPS-MALA achieve sampling error $\epsilon$ in total variation (TV) distance within $\tilde{\mathcal{O}}(d\epsilon^{-2})$ and $\tilde{\mathcal{O}}(d^{1/2}\epsilon^{-2})$ gradient complexities, respectively, which improve on the best-known result by at least an $\tilde{\mathcal{O}}(d^{1/3})$ factor. Empirical studies on synthetic data of various dimensions corroborate this improvement and demonstrate the efficiency of the proposed algorithms.
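SPS-MALA replaces the inner SGLD loop with MALA, which needs energy values as well as gradients for its Metropolis-Hastings correction. The sketch below is an illustration under assumptions rather than the paper's algorithm: hypothetical helpers `fi(x, idx)` and `grad_fi(x, idx)` return the minibatch energy $f_{\mathbf{b}}(x)$ and its gradient, one minibatch $\mathbf{b}$ is drawn per outer iteration, and MALA then targets $\exp(-f_{\mathbf{b}}(x) - \|x-y\|^2/(2\eta))$ for a fixed number of steps starting from the current $x$ (a single state of shape `(d,)`).

```python
import numpy as np


def mala(g, grad_g, z0, step, n_steps, rng):
    """Run n_steps of MALA targeting a density proportional to exp(-g(z))."""
    z = np.asarray(z0, dtype=float).copy()
    gz, dz = g(z), grad_g(z)
    for _ in range(n_steps):
        # Langevin proposal: N(z - step * grad g(z), 2 * step * I)
        prop = z - step * dz + np.sqrt(2 * step) * rng.standard_normal(z.shape)
        gp, dp = g(prop), grad_g(prop)
        # Log proposal densities (normalizing constants cancel in the ratio)
        log_q_fwd = -np.sum((prop - z + step * dz) ** 2) / (4 * step)
        log_q_bwd = -np.sum((z - prop + step * dp) ** 2) / (4 * step)
        # Metropolis-Hastings acceptance test in log space
        if np.log(rng.uniform()) < (gz - gp) + (log_q_bwd - log_q_fwd):
            z, gz, dz = prop, gp, dp
    return z


def sps_mala(fi, grad_fi, n, x0, eta, K, inner_steps, inner_step, batch_size, rng):
    """Hypothetical SPS-MALA sketch: one minibatch per outer step, MALA inner loop."""
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(K):
        idx = rng.choice(n, size=batch_size, replace=False)  # minibatch b
        y = x + np.sqrt(eta) * rng.standard_normal(x.shape)  # y | x ~ N(x, eta I)

        def g(z, y=y, idx=idx):
            # Inner energy f_b(z) + ||z - y||^2 / (2 eta)
            return fi(z, idx) + np.sum((z - y) ** 2) / (2 * eta)

        def grad_g(z, y=y, idx=idx):
            return grad_fi(z, idx) + (z - y) / eta

        x = mala(g, grad_g, x, inner_step, inner_steps, rng)
    return x
```

Fixing $\mathbf{b}$ within an outer iteration keeps the MALA target well defined, so the accept/reject step is exact for the randomized inner target; only the outer randomization over $\mathbf{b}$ introduces error relative to $p_*$.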

Paper Structure (25 sections, 29 theorems, 252 equations, 2 figures, 5 tables, 4 algorithms)

Key Result

Theorem 3.1

Suppose Assumptions con_ass:lips_loss–con_ass:var_bound hold and Alg. alg:sps satisfies:
Then, we have

Figures (2)

  • Figure 1: The background of each panel is the projection of the negative log density onto a 2D plane, and the nodes are the projections of the particles returned by the different algorithms onto the same plane. The first two rows show the distribution of the projected particles after different numbers of iterations of SGLD and SPS-SGLD, each with its optimal step size, for $d=10$.
  • Figure 2: The left column shows the estimated TV distance, i.e., $\mathrm{TV}(\hat{p}_K, p_*)$, when SGLD and SPS-SGLD use their optimal hyperparameters across different dimensions. The right column shows the estimated TV distance when SGLD and SPS-SGLD use different step sizes, with $d=10$.
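Figure 2 reports estimates of $\mathrm{TV}(\hat{p}_K, p_*)$. The source does not say how these estimates are computed; one common choice, shown here purely as an assumption, is a histogram plug-in estimator that bins samples from $\hat{p}_K$ and reference samples from $p_*$ on a shared grid and takes half the $\ell_1$ distance between the normalized bin counts. It is only reliable in low dimensions (for example, on a 1D or 2D projection) because the number of bins grows exponentially with dimension. Binning can only lower the TV relative to the continuous value, while finite samples bias the estimate upward.

```python
import numpy as np


def tv_histogram(samples_p, samples_q, bins=30):
    """Histogram plug-in estimate of TV(p, q) from two sample sets (low dimension only)."""
    p = np.asarray(samples_p, dtype=float)
    q = np.asarray(samples_q, dtype=float)
    if p.ndim == 1:  # treat 1D input as n samples of a scalar
        p, q = p[:, None], q[:, None]
    both = np.vstack([p, q])
    # Shared, equal-width bin edges per coordinate so both histograms use one partition
    edges = [np.linspace(both[:, j].min(), both[:, j].max(), bins + 1)
             for j in range(both.shape[1])]
    hist_p, _ = np.histogramdd(p, bins=edges)
    hist_q, _ = np.histogramdd(q, bins=edges)
    # TV between the binned distributions: half the L1 distance of normalized counts
    return 0.5 * np.abs(hist_p / hist_p.sum() - hist_q / hist_q.sum()).sum()
```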

Theorems & Definitions (51)

  • Theorem 3.1
  • Lemma 3.2
  • Theorem 4.1
  • Theorem 4.2
  • Lemma A.1
  • proof
  • Lemma B.1: variant of data-processing inequality
  • proof
  • Lemma B.2: strong log-concavity and smoothness of inner target functions
  • proof
  • ...and 41 more
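Lemma B.2 concerns the strong log-concavity and smoothness of the inner target functions. Assuming the inner target has the proximal form $\exp(-g_y(x))$ with $g_y(x) = f_{\mathbf{b}}(x) + \|x-y\|^2/(2\eta)$, and that $f_{\mathbf{b}}$ is twice differentiable with $L$-Lipschitz gradient, the standard argument is a one-line Hessian bound; the lemma's exact constants and conditions in the paper may differ.

```latex
\nabla^2 g_y(x) = \nabla^2 f_{\mathbf{b}}(x) + \frac{1}{\eta} I,
\qquad -L I \preceq \nabla^2 f_{\mathbf{b}}(x) \preceq L I
\;\Longrightarrow\;
\Bigl(\frac{1}{\eta} - L\Bigr) I \preceq \nabla^2 g_y(x) \preceq \Bigl(\frac{1}{\eta} + L\Bigr) I .
```

So for $\eta < 1/L$ the inner target is $(1/\eta - L)$-strongly log-concave and $(1/\eta + L)$-smooth, with condition number $\kappa = (1+\eta L)/(1-\eta L)$. For example, $\eta = 1/(2L)$ gives $\kappa = (3/2)/(1/2) = 3$, so each inner problem is well conditioned regardless of how non-convex $f$ is globally.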