
PSI3D: Plug-and-Play 3D Stochastic Inference with Slice-wise Latent Diffusion Prior

Wenhan Guo, Jinglun Yu, Yaning Wang, Jin U. Kang, Yu Sun

TL;DR

PSI3D extends plug-and-play Bayesian inference to massive 3D volumes by combining slice-wise latent-diffusion priors with depth-wise TV regularization in a three-step Gibbs sampling scheme. The method uses a VQGAN encoder so that diffusion runs in latent space rather than pixel space, enabling scalable sampling and uncertainty quantification for high-dimensional inverse problems. Experiments on OCT super-resolution show higher reconstruction quality and credible posterior estimates compared with traditional and learning-based baselines. The approach is modular and applicable to other imaging modalities and forward models, offering robust 3D inference at scale.
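One way to organize a three-step Gibbs scheme of this kind is a split Gibbs sampler. The volume x is tied by quadratic coupling terms to a prior variable z, updated slice by slice, and a TV variable w, regularized along depth. The sketch below is a toy illustration under assumptions that do not come from the text, and it is not the paper's exact factorization:

- The forward model is a 0/1 sampling mask with Gaussian noise.
- Both auxiliary variables share one coupling parameter rho.
- The slice-wise latent diffusion step is replaced by a per-slice Gaussian shrinkage sampler.
- The TV step uses a few unadjusted Langevin (ULA) steps on a smoothed depth-wise TV.

All names and parameters (`split_gibbs`, `toy_slice_prior_sample`, `rho`, `tau`, `lam`, `step`) are hypothetical.

```python
import numpy as np


def tv_depth_grad(w, eps=0.05):
    """Gradient of a smoothed depth-wise TV: sum_k sqrt((w[..., k+1] - w[..., k])**2 + eps**2)."""
    d = np.diff(w, axis=-1)              # forward differences along the depth (last) axis
    g = d / np.sqrt(d**2 + eps**2)       # derivative of the smoothed absolute value
    grad = np.zeros_like(w)
    grad[..., :-1] -= g                  # term k depends on w[..., k] with sign -1
    grad[..., 1:] += g                   # ... and on w[..., k+1] with sign +1
    return grad


def toy_slice_prior_sample(x_slice, rho, tau, rng):
    """Hypothetical stand-in for the slice-wise latent diffusion step.

    Samples z ~ exp(-|x - z|^2 / (2 rho^2) - |z|^2 / (2 tau^2)), a Gaussian
    denoising posterior; in PSI3D a 2D latent diffusion model would act here.
    """
    shrink = tau**2 / (tau**2 + rho**2)
    return shrink * x_slice + np.sqrt(shrink) * rho * rng.standard_normal(x_slice.shape)


def split_gibbs(y, mask, sigma=0.1, rho=0.2, tau=1.0, lam=5.0,
                n_iter=200, burn_in=50, n_langevin=10, step=5e-3, seed=0):
    rng = np.random.default_rng(seed)
    z, w = y.copy(), y.copy()
    prec_x = mask / sigma**2 + 2.0 / rho**2   # diagonal precision, since A is a 0/1 mask
    samples = []
    for it in range(n_iter):
        # Step 1 (likelihood): x | y, z, w is Gaussian with diagonal precision.
        mean_x = (mask * y / sigma**2 + (z + w) / rho**2) / prec_x
        x = mean_x + rng.standard_normal(y.shape) / np.sqrt(prec_x)
        # Step 2 (prior): each 2D slice along the depth axis is updated independently.
        z = np.stack([toy_slice_prior_sample(x[..., k], rho, tau, rng)
                      for k in range(x.shape[-1])], axis=-1)
        # Step 3 (depth-wise TV): unadjusted Langevin steps on
        # U(w) = |w - x|^2 / (2 rho^2) + lam * TV_depth(w).
        for _ in range(n_langevin):
            grad_u = (w - x) / rho**2 + lam * tv_depth_grad(w)
            w = w - step * grad_u + np.sqrt(2 * step) * rng.standard_normal(w.shape)
        if it >= burn_in:
            samples.append(x)
    return np.stack(samples)


# Toy usage: a block phantom observed through a random 50% mask.
rng = np.random.default_rng(1)
truth = np.zeros((16, 16, 8))
truth[4:12, 4:12, :] = 1.0
mask = (rng.random(truth.shape) < 0.5).astype(float)
y = mask * (truth + 0.1 * rng.standard_normal(truth.shape))
samples = split_gibbs(y, mask)
post_mean, post_sd = samples.mean(axis=0), samples.std(axis=0)
```

The design choice here is that each conditional touches only one kind of structure. The likelihood step is closed-form, the prior step factorizes over 2D slices, and only the TV step couples slices along depth. ULA is an approximate sampler, so in this sketch the TV conditional is sampled only approximately.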

Abstract

Diffusion models are highly expressive image priors for Bayesian inverse problems. However, most diffusion models cannot operate on large-scale, high-dimensional data due to high training and inference costs. In this work, we introduce a Plug-and-play algorithm for 3D stochastic inference with a latent diffusion prior (PSI3D) to address massive ($1024\times 1024\times 128$) volumes. Specifically, we formulate a Markov chain Monte Carlo approach to reconstruct each two-dimensional (2D) slice by sampling from a 2D latent diffusion model. To enhance inter-slice consistency, we also incorporate total variation (TV) regularization stochastically along the concatenation axis. We evaluate our method on optical coherence tomography (OCT) super-resolution. Our method significantly improves reconstruction quality for large-scale scientific imaging compared to traditional and learning-based baselines, while providing robust and credible reconstructions.
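To make "massive" concrete, here is the voxel count of one volume, its memory footprint and the size of a single 2D slice. The memory figure assumes 4-byte float32 storage, which the text does not specify.

$$
\begin{aligned}
1024 \times 1024 \times 128 &= 2^{10}\cdot 2^{10}\cdot 2^{7} = 2^{27} = 134{,}217{,}728 \ \text{voxels},\\
2^{27} \times 4\ \text{bytes} &= 2^{29}\ \text{bytes} = 512\ \text{MiB per float32 copy},\\
1024 \times 1024 &= 2^{20} = 1{,}048{,}576 \ \text{pixels per slice} = 2^{27}/2^{7}.
\end{aligned}
$$

An MCMC sampler keeps several such iterates in memory, plus whatever samples are retained for statistics. Working on one $2^{20}$-pixel slice at a time, in a latent space, therefore plausibly reduces the per-step cost by a large factor.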

Paper Structure

This paper contains 5 sections, 15 equations, 3 figures, 1 table, 1 algorithm.

Figures (3)

  • Figure 1: Visual comparison of the 3D OCT volumes ($1024 \times 1024 \times 128$) reconstructed by PSI3D and 3D UNet. From left to right, we show the ground truth volume, PSI3D reconstructed volume, absolute error for PSI3D, and absolute error for 3D UNet. Note that PSI3D accurately reconstructs the volume with minimal artifacts and a higher peak signal-to-noise ratio (PSNR).
  • Figure 2: Visual comparison of slices from 3D reconstructions obtained using PSI3D and baseline methods. Each row shows a single B-scan slice in the OCT volume, with a zoom-in view in the yellow boxes. Note that PSI3D accurately recovers fine anatomical details and achieves higher PSNR and SSIM than baseline methods.
  • Figure 3: Visualization of pixel-wise statistics associated with the example volume shown in Fig. 1. We plot the volume's absolute error ($|\bar{\mathbf{x}} - \mathbf{x}|$), $3\times$ standard deviation ($\text{SD}_\mathbf{x}$), and 3-SD credible interval with $10{,}000$ voxels randomly selected from the volume.
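The statistics in Figure 3 can be computed from a stack of posterior samples. The sketch below returns four things:

- the absolute error of the posterior mean, $|\bar{\mathbf{x}} - \mathbf{x}|$;
- a $3\times$ SD map;
- the voxels drawn at random, without replacement, for the interval check;
- the fraction of those voxels whose true value lies inside mean ± 3 SD.

The following are assumptions, not taken from the text:

- the interval is defined as mean ± 3 SD;
- the sample SD uses `ddof=1`;
- the samples here are synthetic stand-ins for sampler output.

The function name `voxel_statistics` is hypothetical.

```python
import numpy as np


def voxel_statistics(samples, truth, n_voxels=10_000, k=3.0, seed=0):
    """Pixel-wise posterior summaries from samples of shape (S, *volume_shape)."""
    post_mean = samples.mean(axis=0)
    post_sd = samples.std(axis=0, ddof=1)
    abs_err = np.abs(post_mean - truth)      # |x_bar - x| per voxel
    band = k * post_sd                       # k x SD map (k = 3 in the figure)
    # Randomly select voxels without replacement (10,000 in the figure caption).
    rng = np.random.default_rng(seed)
    idx = rng.choice(truth.size, size=min(n_voxels, truth.size), replace=False)
    # The true value lies in [mean - k*SD, mean + k*SD] iff |mean - truth| <= k*SD.
    inside = abs_err.ravel()[idx] <= band.ravel()[idx]
    coverage = inside.mean()
    return abs_err, band, idx, coverage


# Synthetic stand-in: 200 noisy "posterior samples" around a random volume.
rng = np.random.default_rng(0)
truth = rng.random((32, 32, 16))
samples = truth + 0.1 * rng.standard_normal((200, 32, 32, 16))
abs_err, band, idx, coverage = voxel_statistics(samples, truth)
print(f"fraction of true values inside the 3-SD interval: {coverage:.3f}")
```

A well-calibrated posterior should place most true values inside the 3-SD band. Coverage well below that level would suggest the sampler underestimates uncertainty.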