Tensor train based sampling algorithms for approximating regularized Wasserstein proximal operators

Fuqun Han, Stanley Osher, Wuchen Li

TL;DR

This work develops a tensor-train (TT) based sampling framework that uses a regularized Wasserstein proximal operator to approximate the density evolution of overdamped Langevin dynamics in high dimensions. Representing the kernel in TT format makes computation and storage scalable, and the authors prove unbiasedness and linear convergence in the Gaussian setting. They introduce TT-BRWP, a noise-free, unbiased variant that uses a carefully chosen covariance update to stabilize density estimation, and they provide theoretical analyses for Gaussian targets and simplified Bayesian inverse problems together with practical computational considerations. Numerical experiments on Gaussian, multimodal, nonconvex, and Bayesian inverse problems show that TT-BRWP outperforms classical Langevin-type samplers and BRWP with Monte Carlo (MC) integration in accuracy, convergence speed, and robustness. The method may therefore enable efficient, accurate sampling for high-dimensional Bayesian inference and inverse problems where traditional MCMC approaches struggle.

Abstract

We present a tensor train (TT) based algorithm for sampling from a target distribution, which uses TT approximation to capture the high-dimensional probability density evolution of overdamped Langevin dynamics. The algorithm relies on the regularized Wasserstein proximal operator, which admits a simple kernel integration formula, i.e., a softmax version of the traditional proximal operator. Because this integral is taken over $\mathbb{R}^d$, it is challenging to evaluate in practice, and the algorithm becomes practically implementable only with the aid of TT approximation. For Gaussian distributions, we rigorously establish the unbiasedness and linear convergence of our sampling algorithm towards the target distribution. To assess the effectiveness of the proposed methods, we apply them in numerical examples to Gaussian families, Gaussian mixtures, bimodal distributions, and Bayesian inverse problems. The sampling algorithm exhibits higher accuracy and faster convergence than classical Langevin dynamics-type sampling algorithms.
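The kernel formula mentioned in the abstract can be illustrated in one dimension, where the integral over $\mathbb{R}^d$ reduces to a quadrature on a grid and no TT approximation is needed. This is a minimal sketch that assumes the kernel has the form $K(x,y) \propto \exp\big(-(V(x) + |x-y|^2/(2T))/(2\beta)\big)$, normalized in $x$ for each fixed $y$; that form is our reading of the regularized proximal kernel, and the paper's exact constants and scaling may differ. The function name `regularized_wasserstein_prox_1d`, the potential `V`, and the initial density are hypothetical choices. Normalizing over $x$ is a softmax of the negative proximal objective $V(x) + |x-y|^2/(2T)$, which concentrates at the classical proximal point as $\beta \to 0$.

```python
import numpy as np


def regularized_wasserstein_prox_1d(rho0, V, grid, T=0.1, beta=1.0):
    """Sketch: rho_T(x) = ∫ K(x, y) rho0(y) dy on a uniform 1D grid.

    ASSUMED kernel (normalized in x for each y):
        K(x, y) ∝ exp(-(V(x) + |x - y|^2 / (2T)) / (2 beta)).
    """
    dx = grid[1] - grid[0]
    x = grid[:, None]  # rows index the output point x
    y = grid[None, :]  # columns index the source point y
    logits = -(V(x) + (x - y) ** 2 / (2.0 * T)) / (2.0 * beta)
    # Softmax over x for each fixed y (shift by the column max for stability).
    logits = logits - logits.max(axis=0, keepdims=True)
    K = np.exp(logits)
    K /= K.sum(axis=0, keepdims=True) * dx  # each column now integrates to one
    # Riemann-sum quadrature in y.
    return K @ rho0 * dx


def V(x):
    # Hypothetical potential: the target is N(0, 1) when beta = 1.
    return 0.5 * x**2


grid = np.linspace(-6.0, 6.0, 601)
dx = grid[1] - grid[0]
rho0 = np.exp(-((grid - 2.0) ** 2) / (2 * 0.25))  # hypothetical initial density N(2, 0.5^2)
rho0 /= rho0.sum() * dx

rho_T = regularized_wasserstein_prox_1d(rho0, V, grid, T=0.1, beta=1.0)
mass = rho_T.sum() * dx
mean = (grid * rho_T).sum() * dx
print(f"mass = {mass:.6f}, mean = {mean:.4f}")
```

Because each kernel column integrates to one, the total mass of `rho_T` stays 1. For this quadratic `V`, a short Gaussian calculation gives a mean of $2 \cdot 10/11 \approx 1.82$, so the density is pulled toward the minimizer of `V`. In $d$ dimensions the same dense grid would have $n^d$ points and the kernel matrix $n^{2d}$ entries, which is the cost a TT representation is meant to avoid.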

Paper Structure

This paper contains 14 sections, 9 theorems, 102 equations, 10 figures, 1 table, and 1 algorithm.

Key Result

Theorem 3.1

Let $f \in H^{k+1}(\mathbb{R}^d)$ for some fixed $k > 0$, and let $0 < \epsilon < 1$. Then the overall truncation error of the TT decomposition of $f$ with ranks $r \leq \epsilon^{-d/k}$ is given by […], and the storage cost of the TT representation is $\epsilon^{-d/k}$.
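The link between TT ranks and storage in Theorem 3.1 can be seen with the standard TT-SVD algorithm, which builds the cores from a sequence of truncated SVDs. This is a minimal NumPy sketch, not necessarily the TT construction the authors use (this summary does not say which one they use); `tt_svd`, `tt_to_full`, and the test function $\sin(x_1 + x_2 + x_3 + x_4)$ are illustrative choices. Each core $G_k$ has shape $(r_{k-1}, n_k, r_k)$, so storage grows like $\sum_k r_{k-1} n_k r_k$ rather than $\prod_k n_k$.

```python
import numpy as np


def tt_svd(tensor, max_rank):
    """TT-SVD sketch: split a d-way array into cores of shape (r_{k-1}, n_k, r_k)."""
    shape = tensor.shape
    cores, r_prev = [], 1
    mat = tensor.reshape(shape[0], -1)
    for k in range(len(shape) - 1):
        u, s, vt = np.linalg.svd(mat, full_matrices=False)
        r = min(max_rank, s.size)  # truncate the k-th TT rank
        cores.append(u[:, :r].reshape(r_prev, shape[k], r))
        # Carry the remainder S V^T forward and unfold it for the next mode.
        mat = (s[:r, None] * vt[:r]).reshape(r * shape[k + 1], -1)
        r_prev = r
    cores.append(mat.reshape(r_prev, shape[-1], 1))
    return cores


def tt_to_full(cores):
    """Contract TT cores back into a dense array (only for checking small examples)."""
    out = cores[0]
    for core in cores[1:]:
        out = np.tensordot(out, core, axes=([-1], [0]))
    return out.reshape(out.shape[1:-1])


x = np.linspace(0.0, 1.0, 10)
grids = np.meshgrid(x, x, x, x, indexing="ij")
f = np.sin(sum(grids))  # sin(x1 + x2 + x3 + x4): every unfolding has rank 2

cores = tt_svd(f, max_rank=2)
err = np.max(np.abs(tt_to_full(cores) - f))
storage = sum(core.size for core in cores)
print([core.shape for core in cores], err, storage, f.size)
```

Since $\sin(a+b) = \sin a \cos b + \cos a \sin b$, every unfolding of this tensor has rank 2, so ranks capped at 2 reproduce it to round-off with 120 stored numbers instead of $10^4$. For a smooth $f$ that is not exactly low-rank, lowering `max_rank` trades accuracy for storage, which is the trade-off the theorem quantifies.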

Figures (10)

  • Figure 1: Logarithm of the approximation error for the variance versus iterations, using the empirical distribution (red) and the kernel density estimate (blue), with $\sigma = 2, 0.5, 0.25$ (from left to right), $T = 0.1$, and $h = 0.1$.
  • Figure 2: Logarithm of the approximation error for the variance versus iterations for the TT-based sampling algorithm with terminal times $T = 0.05, 0.1, 0.2, 0.4$, $h = 0.1$, and $\sigma = 0.5$ (left) or $\sigma = 1$ (right).
  • Figure 3: Example 1. The logarithm of the $L^2$ error for the variance with samples generated from TT-BRWP (blue), BRWP (red), and ULA (black).
  • Figure 4: Example 2. The error in the mean of samples generated by TT-BRWP (blue), BRWP (red), and ULA (black) versus iterations.
  • Figure 5: Example 3. Evolution of particles under different algorithms after 10 iterations (first row) and 15 iterations (second row) for different initial distributions. The contour lines mark the $0.8$ and $0.3$ levels of the density function.
  • ...and 5 more figures
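Figures 3 and 4 use the unadjusted Langevin algorithm (ULA), $x_{k+1} = x_k - h\nabla V(x_k) + \sqrt{2h}\,\xi_k$ with $\xi_k \sim \mathcal{N}(0, I)$, as a baseline. For a Gaussian target this baseline has a known step-size bias, which gives context for the unbiasedness result stated in the abstract. This sketch assumes the 1D target $\mathcal{N}(0, \sigma^2)$, i.e., $V(x) = x^2/(2\sigma^2)$ with unit temperature; the function names, initial samples, and parameters are hypothetical. Solving the variance recursion $s = (1 - h/\sigma^2)^2 s + 2h$ gives the stationary variance $2\sigma^2/(2 - h/\sigma^2)$, which exceeds $\sigma^2$ for every stable step size $0 < h < 2\sigma^2$.

```python
import numpy as np


def ula_gaussian(sigma=1.0, h=0.1, n_particles=20_000, n_iters=300, seed=0):
    """ULA for the 1D target N(0, sigma^2), i.e. V(x) = x^2 / (2 sigma^2), unit temperature."""
    rng = np.random.default_rng(seed)
    x = rng.normal(loc=3.0, scale=0.5, size=n_particles)  # hypothetical initial samples
    for _ in range(n_iters):
        grad_V = x / sigma**2
        # Gradient step plus injected Gaussian noise (this noise is what ULA needs).
        x = x - h * grad_V + np.sqrt(2.0 * h) * rng.standard_normal(n_particles)
    return x


def ula_stationary_variance(sigma, h):
    """Fixed point of s = (1 - h / sigma^2)^2 s + 2h, the variance ULA settles at."""
    return 2.0 * sigma**2 / (2.0 - h / sigma**2)


samples = ula_gaussian(sigma=1.0, h=0.1)
print(np.var(samples), ula_stationary_variance(1.0, 0.1))  # the target variance is 1.0
```

With $\sigma = 1$ and $h = 0.1$ the predicted stationary variance is $2/1.9 \approx 1.053$, so the sample variance stalls about 5% above the target however many iterations are run; only shrinking $h$ reduces this bias.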

Theorems & Definitions (13)

  • Theorem 3.1
  • Theorem 3.2
  • Theorem 3.3
  • proof
  • Lemma 4.1
  • Lemma 4.2
  • proof
  • Theorem 4.3
  • proof
  • Theorem 4.4
  • ...and 3 more