Diffusion-PINN Sampler

Zhekun Shi, Longlin Yu, Tianyu Xie, Cheng Zhang

TL;DR

This paper introduces the Diffusion-PINN Sampler (DPS), a novel diffusion-based sampling algorithm that estimates the drift term of the reverse SDE by solving the governing partial differential equation (PDE) for the log-density of the underlying SDE marginals with physics-informed neural networks (PINNs).
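As a minimal sketch of how a log-density estimate turns into a sampling drift, the Python below runs the reverse-time SDE with Euler–Maruyama. It assumes a hypothetical 1D toy setting that is not from the paper: a forward Ornstein–Uhlenbeck process $dx = -x\,dt + \sqrt{2}\,dW$ and a Gaussian target $N(2, 0.5^2)$, where $\nabla_x \log p_t$ is available in closed form. In DPS this gradient would instead come from the PINN-learned log-density. All function names are illustrative.

```python
import math
import random

# Hypothetical toy setup (not from the paper):
# forward OU process dx = -x dt + sqrt(2) dW, i.e. f(x, t) = -x and g^2 = 2,
# and a 1D Gaussian target N(M0, S0^2), so every marginal p_t is Gaussian.
M0, S0 = 2.0, 0.5


def marginal_params(t):
    """Mean and variance of p_t under the assumed OU forward process."""
    mean = M0 * math.exp(-t)
    var = 1.0 + (S0 ** 2 - 1.0) * math.exp(-2.0 * t)
    return mean, var


def exact_score(x, t):
    """Closed-form d/dx log p_t(x); a learned log-density gradient would go here."""
    mean, var = marginal_params(t)
    return -(x - mean) / var


def reverse_sample(n=2000, T=2.0, steps=400, seed=0):
    """Euler-Maruyama on the reverse SDE dy = [-f + g^2 * score] dtau + g dW."""
    rng = random.Random(seed)
    dt = T / steps
    mean_T, var_T = marginal_params(T)
    ys = [rng.gauss(mean_T, math.sqrt(var_T)) for _ in range(n)]  # start from p_T
    for k in range(steps):
        t = T - k * dt  # forward time matching reverse step k
        # With f = -y and g^2 = 2, the reverse drift is y + 2 * score(y, t).
        ys = [
            y + (y + 2.0 * exact_score(y, t)) * dt + math.sqrt(2.0 * dt) * rng.gauss(0.0, 1.0)
            for y in ys
        ]
    return ys


samples = reverse_sample()
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"sample mean = {mean:.3f}, sample std = {std:.3f} (target: 2.0, 0.5)")
```

In this sketch, `exact_score` is the single place where a learned model would enter; the horizon `T` and step count are arbitrary choices for the toy.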

Abstract

The recent success of diffusion models has inspired a surge of interest in developing sampling techniques based on reverse diffusion processes. However, accurately estimating the drift term in the reverse stochastic differential equation (SDE) solely from the unnormalized target density poses significant challenges, which keep existing methods from achieving state-of-the-art performance. In this paper, we introduce the Diffusion-PINN Sampler (DPS), a novel diffusion-based sampling algorithm that estimates the drift term by solving the governing partial differential equation of the log-density of the underlying SDE marginals via physics-informed neural networks (PINNs). We prove that the log-density approximation error can be controlled by the PINN residual loss, which allows us to establish convergence guarantees for DPS. Experiments on a variety of sampling tasks demonstrate the effectiveness of our approach, particularly in accurately identifying mixing proportions when the target contains isolated components.
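To make the PINN residual loss concrete, the sketch below evaluates the PDE residual of a candidate log-density at a few collocation points using central finite differences. It assumes a hypothetical forward Ornstein–Uhlenbeck process $d\bm{x} = -\bm{x}\,dt + \sqrt{2}\,d\bm{w}$ in $\mathbb{R}^d$, for which the log-density PDE reduces to $\partial_t u = d + \bm{x}\cdot\nabla u + \Delta u + \|\nabla u\|^2$, and an assumed Gaussian target. A real PINN would replace the candidate functions with a neural network and use automatic differentiation; the names here are illustrative only. The exact Gaussian solution should give a residual loss near zero, while a candidate frozen at $t=0$ should not.

```python
import math

# Hypothetical setup: forward OU process dx = -x dt + sqrt(2) dW in R^d, so the
# log-density PDE is  du/dt = d + x . grad u + lap u + |grad u|^2.
M0 = (1.0, -1.0)  # assumed target mean (isotropic Gaussian target)
S0_SQ = 0.25      # assumed target variance per coordinate


def exact_u(x, t):
    """log p_t(x) for the assumed Gaussian target pushed through the OU process."""
    var = 1.0 + (S0_SQ - 1.0) * math.exp(-2.0 * t)
    sq = sum((xi - mi * math.exp(-t)) ** 2 for xi, mi in zip(x, M0))
    return -sq / (2.0 * var) - 0.5 * len(x) * math.log(2.0 * math.pi * var)


def frozen_u(x, t):
    """A wrong candidate: the t = 0 log-density, held fixed for all t."""
    return exact_u(x, 0.0)


def pde_residual(u, x, t, h=1e-3):
    """du/dt - (d + x.grad u + lap u + |grad u|^2), via central finite differences."""
    d = len(x)
    u0 = u(x, t)
    du_dt = (u(x, t + h) - u(x, t - h)) / (2.0 * h)
    grad, lap = [], 0.0
    for i in range(d):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        up, um = u(xp, t), u(xm, t)
        grad.append((up - um) / (2.0 * h))
        lap += (up - 2.0 * u0 + um) / h ** 2
    rhs = d + sum(xi * gi for xi, gi in zip(x, grad)) + lap + sum(g * g for g in grad)
    return du_dt - rhs


def residual_loss(u, points):
    """Mean squared PDE residual over collocation points (x, t)."""
    return sum(pde_residual(u, x, t) ** 2 for x, t in points) / len(points)


points = [((a, b), t) for a in (-1.0, 0.0, 1.0) for b in (-1.0, 0.5) for t in (0.2, 0.5, 0.8)]
loss_exact = residual_loss(exact_u, points)
loss_frozen = residual_loss(frozen_u, points)
print(f"exact: {loss_exact:.2e}, frozen: {loss_frozen:.2e}")
```

In an actual PINN, a loss of this kind would be minimized over network parameters; per the abstract, it is this residual that controls the log-density approximation error.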

Paper Structure

This paper contains 44 sections, 11 theorems, 107 equations, 7 figures, 7 tables, 2 algorithms.

Key Result

Theorem 1

Assume the density $p_t(\bm{x})$ is sufficiently smooth on $\mathbb{R}^d\times[0,T]$. Then, for all $(\bm{x}, t) \in \mathbb{R}^d\times[0,T]$, the log-density $u_t(\bm{x}):=\log p_t(\bm{x})$ satisfies the log-density Fokker–Planck equation (FPE), and the score $\bm{s}_t(\bm{x}):=\nabla_{\bm{x}}\log p_t(\bm{x})$ satisfies the score FPE.
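For orientation, the two PDEs in Theorem 1 take the following standard form, assuming a forward SDE $d\bm{x} = \bm{f}(\bm{x},t)\,dt + g(t)\,d\bm{w}_t$ with a scalar diffusion coefficient. The symbols $\bm{f}$, $g$, and $\bm{w}_t$ are assumed notation, not taken from this summary, and the paper's own statement may be written differently. Dividing the Fokker–Planck equation for $p_t$ by $p_t$ and using $\Delta p_t / p_t = \Delta u_t + \|\nabla u_t\|^2$ gives the log-density equation; taking $\nabla_{\bm{x}}$ of it, with $\bm{s}_t = \nabla_{\bm{x}} u_t$, gives the score equation.

```latex
\begin{aligned}
\partial_t p_t(\bm{x}) &= -\nabla_{\bm{x}}\cdot\big(\bm{f}(\bm{x},t)\,p_t(\bm{x})\big) + \tfrac{g(t)^2}{2}\,\Delta_{\bm{x}} p_t(\bm{x}),\\
\partial_t u_t(\bm{x}) &= -\nabla_{\bm{x}}\cdot\bm{f}(\bm{x},t) - \bm{f}(\bm{x},t)\cdot\nabla_{\bm{x}} u_t(\bm{x}) + \tfrac{g(t)^2}{2}\big(\Delta_{\bm{x}} u_t(\bm{x}) + \|\nabla_{\bm{x}} u_t(\bm{x})\|^2\big),\\
\partial_t \bm{s}_t(\bm{x}) &= \nabla_{\bm{x}}\Big[-\nabla_{\bm{x}}\cdot\bm{f}(\bm{x},t) - \bm{f}(\bm{x},t)\cdot\bm{s}_t(\bm{x}) + \tfrac{g(t)^2}{2}\big(\nabla_{\bm{x}}\cdot\bm{s}_t(\bm{x}) + \|\bm{s}_t(\bm{x})\|^2\big)\Big].
\end{aligned}
```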

Figures (7)

  • Figure 1: Left: KL divergence, Fisher divergence, and log-density error between $\pi^M$ and $\tilde{\pi}^M$ as functions of $w_1$, where $\tilde{w}_1=0.2$ and $\bm{a}=(-5,-5)'$. Middle/Right: Evolution of the log-density error and the Fisher divergence, respectively, along the forward process, which reaches the standard Gaussian at $t=1$.
  • Figure 2: Score estimation by solving the log-density FPE with a PINN versus by denoising score matching.
  • Figure 3: Sampling performance of different methods for 9-Gaussians ($d=2$), Rings ($d=2$), Funnel ($d=10$), and Double-well ($d=30$).
  • Figure 4: Left: PINN residual loss and score approximation error when solving the score FPE and the log-density FPE; Middle/Right: Marginals of the first dimension produced by DPS when solving the score/log-density FPE, respectively, for a two-mode mixture of Gaussians (MoG).
  • Figure 5: Left: KL divergence to the ground truth while solving the log-density FPE with different regularizations for Funnel. Middle/Right: Sampling performance of DPS with/without regularization for Funnel.
  • ...and 2 more figures
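The left panel of Figure 1 compares KL, Fisher divergence, and log-density error across mixing weights. As a rough illustration of the well-known property behind it (not a reproduction of the paper's figure), the sketch below uses a hypothetical 1D analogue $w\,N(-5,1) + (1-w)\,N(5,1)$ with $w=0.5$ against $\tilde{w}=0.2$: because the mixture score depends on the weights only in the low-density gap between isolated modes, the Fisher divergence stays tiny while the KL divergence does not. Integrals use a simple midpoint rule; all names are illustrative.

```python
import math

A = 5.0  # hypothetical 1D separation: components N(-A, 1) and N(A, 1)


def log_normal(x, mu):
    return -0.5 * (x - mu) ** 2 - 0.5 * math.log(2.0 * math.pi)


def log_mix(x, w):
    """log of w*N(-A,1) + (1-w)*N(A,1), computed with log-sum-exp."""
    a = math.log(w) + log_normal(x, -A)
    b = math.log(1.0 - w) + log_normal(x, A)
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def score_mix(x, w):
    """d/dx log mixture = -x + A*(1 - 2r), r = posterior weight of N(-A,1)."""
    logit = math.log(w / (1.0 - w)) - 2.0 * A * x
    r = 1.0 / (1.0 + math.exp(-logit))
    return -x + A * (1.0 - 2.0 * r)


def divergences(w, w_tilde, lo=-15.0, hi=15.0, n=30000):
    """KL(pi || pi~) and Fisher divergence E_pi |s - s~|^2 via a midpoint rule."""
    h = (hi - lo) / n
    kl = fisher = 0.0
    for k in range(n):
        x = lo + (k + 0.5) * h
        lp = log_mix(x, w)
        p = math.exp(lp)
        kl += p * (lp - log_mix(x, w_tilde)) * h
        fisher += p * (score_mix(x, w) - score_mix(x, w_tilde)) ** 2 * h
    return kl, fisher


kl, fisher = divergences(0.5, 0.2)
print(f"KL = {kl:.4f}, Fisher = {fisher:.2e}")
```

With well-separated modes the KL divergence here is close to $0.5\log(0.5/0.2) + 0.5\log(0.5/0.8) \approx 0.223$, which is why a log-density criterion can detect a wrong mixing weight that a score-only criterion barely registers.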

Theorems & Definitions (20)

  • Theorem 1: Log-density FPE and score FPE; Proposition 3.1 in lai2023fp
  • Example 1
  • Theorem 2
  • Remark 1
  • Theorem 3
  • Theorem 4
  • Remark 2
  • Theorem 5
  • Proof: Proof of Theorem 1
  • Theorem 6
  • ...and 10 more