PSI3D: Plug-and-Play 3D Stochastic Inference with Slice-wise Latent Diffusion Prior
Wenhan Guo, Jinglun Yu, Yaning Wang, Jin U. Kang, Yu Sun
TL;DR
PSI3D extends plug-and-play Bayesian inference to massive 3D volumes by combining slice-wise latent-diffusion priors with depth-wise total variation (TV) regularization in a three-step Gibbs sampling scheme. The method uses a VQGAN encoder so that diffusion runs in a compressed latent space, which makes sampling and uncertainty quantification tractable for high-dimensional inverse problems. Experiments on optical coherence tomography (OCT) super-resolution show superior reconstruction quality and credible posterior inference compared with traditional and learning-based baselines. The approach is modular and applies broadly to other imaging modalities and forward models, offering robust 3D inference at scale.
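The depth-wise TV term can be illustrated with a short NumPy sketch. It assumes the 2D slices are stacked along axis 0, uses a smoothed (Charbonnier-style) TV so that the gradient is defined everywhere, and takes one unadjusted Langevin step targeting a density proportional to exp(-lam * TV_z(x)). The function names depth_tv, depth_tv_grad, and langevin_tv_step and the parameters lam, step, and eps are hypothetical; the paper's exact TV form and its TV sampling step within the Gibbs scheme may differ.

```python
import numpy as np


def depth_tv(x, eps=1e-3):
    """Smoothed TV along axis 0 only (the axis the 2D slices are stacked on).

    Assumption: |d| is replaced by sqrt(d^2 + eps^2) so the gradient exists at d = 0.
    """
    d = np.diff(x, axis=0)  # d[k] = x[k+1] - x[k], differences between adjacent slices
    return np.sum(np.sqrt(d**2 + eps**2))


def depth_tv_grad(x, eps=1e-3):
    """Gradient of depth_tv with respect to every voxel of x."""
    d = np.diff(x, axis=0)
    w = d / np.sqrt(d**2 + eps**2)  # derivative of sqrt(d^2 + eps^2) w.r.t. d
    g = np.zeros_like(x)
    g[1:] += w   # x[k+1] enters d[k] with sign +1
    g[:-1] -= w  # x[k] enters d[k] with sign -1
    return g


def langevin_tv_step(x, lam, step, rng, eps=1e-3):
    """One unadjusted Langevin step for p(x) proportional to exp(-lam * depth_tv(x)).

    Hypothetical stand-in for a stochastic TV update; not the authors' exact sampler.
    """
    noise = rng.standard_normal(x.shape)
    return x - step * lam * depth_tv_grad(x, eps) + np.sqrt(2.0 * step) * noise
```

Because the differences are taken only along axis 0, the update couples neighboring slices while leaving in-plane structure to the per-slice prior, which mirrors the division of labor described above.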
Abstract
Diffusion models are highly expressive image priors for Bayesian inverse problems. However, most diffusion models cannot be applied directly to large-scale, high-dimensional data because of their high training and inference costs. In this work, we introduce PSI3D, a Plug-and-play algorithm for 3D Stochastic Inference with a latent diffusion prior, to handle massive ($1024\times 1024\times 128$) volumes. Specifically, we formulate a Markov chain Monte Carlo (MCMC) approach that reconstructs each two-dimensional (2D) slice by sampling from a 2D latent diffusion model. To enhance inter-slice consistency, we also incorporate total variation (TV) regularization stochastically along the axis on which the slices are concatenated. We evaluate the method on optical coherence tomography (OCT) super-resolution. Compared with traditional and learning-based baselines, our method significantly improves reconstruction quality for large-scale scientific imaging while providing robust and credible reconstructions.
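As a rough illustration of why a slice-wise prior helps at this scale, the arithmetic below compares the full $1024\times 1024\times 128$ volume with a single 2D slice. It assumes float32 storage and a hypothetical VQGAN downsampling factor of 8; neither value is stated in the abstract.

```python
# Volume size from the abstract; reconstructed as D independent 2D slices.
H, W, D = 1024, 1024, 128
BYTES = 4  # assumption: float32 storage
F = 8      # hypothetical VQGAN spatial downsampling factor (not given in the source)

voxels_volume = H * W * D
voxels_slice = H * W
mib_volume = voxels_volume * BYTES / 2**20  # memory for one copy of the full volume
mib_slice = voxels_slice * BYTES / 2**20    # memory for one 2D slice
latent_hw = (H // F, W // F)                # spatial size of one slice's latent

print(voxels_volume)  # 134217728
print(mib_volume)     # 512.0
print(mib_slice)      # 4.0
print(latent_hw)      # (128, 128)
```

Under these assumptions the diffusion prior only ever sees a 128 by 128 latent per slice, while a single full-resolution copy of the volume already occupies 512 MiB before any network activations are counted.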
