Local Patches Meet Global Context: Scalable 3D Diffusion Priors for Computed Tomography Reconstruction

Taewon Yang, Jason Hu, Jeffrey A. Fessler, Liyue Shen

TL;DR

The paper tackles the challenge of learning scalable 3D diffusion priors for high-resolution CT reconstruction under data and compute constraints. It introduces a global-aware 3D patch diffusion model that jointly learns local 3D patches and a downsampled global volume, enabling efficient generation and accurate reconstruction of 3D CT volumes. Through extensive experiments on LIDC-IDRI and AAPM datasets, the approach achieves state-of-the-art performance in sparse-view CT reconstruction and demonstrates faster inference compared to baselines, while providing detailed ablations on key design choices. The work highlights the value of integrating local patch statistics with global context to form a coherent 3D prior and discusses avenues for improving robustness to less structured data.

Abstract

Diffusion models learn strong image priors that can be leveraged to solve inverse problems such as medical image reconstruction. However, for real-world applications such as 3D Computed Tomography (CT) imaging, directly training diffusion models on 3D data is challenging because it demands extensive GPU resources and large-scale datasets. Existing works mostly reuse 2D diffusion priors to address 3D inverse problems, but fail to fully realize and leverage the generative capacity of diffusion models for high-dimensional data. In this study, we propose a novel 3D patch-based diffusion model that can learn a fully 3D diffusion prior from limited data, enabling scalable generation of high-resolution 3D images. Our core idea is to learn the prior of 3D patches for scalable efficiency, while coupling local and global information to ensure high-quality 3D image generation, by modeling the joint distribution of position-aware 3D local patches and a downsampled 3D volume that serves as global context. Our approach not only enables high-quality 3D generation, but also offers an unprecedentedly efficient and accurate solution to high-resolution 3D inverse problems. Experiments on 3D CT reconstruction across multiple datasets show that our method outperforms state-of-the-art methods in both performance and efficiency, notably reconstructing a high-resolution $512 \times 512 \times 256$ volume in about 20 minutes.
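To make the joint "local patch plus global context" input concrete, here is a minimal NumPy sketch of how one training input could be assembled from a noisy volume $\bm{x}_t$, following the Figure 1 description: a cropped local patch, a downsampled copy of the whole volume, and three coordinate channels normalized to $[-1,1]$. The helper name `build_denoiser_input`, the use of average pooling as the downsampling operator, and the requirement that the volume edge be a multiple of the patch size are assumptions for illustration, not details confirmed by the paper.

```python
import numpy as np


def build_denoiser_input(x_t, corner, patch_size):
    """Hypothetical helper: assemble the patch + global-context + position input.

    x_t:        noisy cubic volume of shape (N, N, N)
    corner:     (a, b, c), first voxel of the local patch along each axis
    patch_size: patch edge length p; N must be a multiple of p (assumption)
    Returns an array of shape (5, p, p, p): patch, global context, 3 coordinate channels.
    """
    n, p = x_t.shape[0], patch_size
    if x_t.shape != (n, n, n) or n % p != 0:
        raise ValueError("expected a cubic volume whose edge is a multiple of patch_size")
    if any(s < 0 or s + p > n for s in corner):
        raise ValueError("patch must lie inside the volume")
    a, b, c = corner

    # Local details: crop the p^3 patch (the role of G_c in Figure 1).
    patch = x_t[a:a + p, b:b + p, c:c + p]

    # Global context: average-pool the whole volume down to p^3 (assumed form of D).
    f = n // p
    global_ctx = x_t.reshape(p, f, p, f, p, f).mean(axis=(1, 3, 5))

    # Positional encoding: global voxel indices of the patch, mapped linearly to [-1, 1].
    coords = [np.arange(s, s + p) / (n - 1) * 2.0 - 1.0 for s in corner]
    g0, g1, g2 = np.meshgrid(*coords, indexing="ij")

    # Concatenate along a leading channel axis to form the denoiser input.
    return np.stack([patch, global_ctx, g0, g1, g2], axis=0)


rng = np.random.default_rng(0)
x_t = rng.standard_normal((64, 64, 64))
inp = build_denoiser_input(x_t, corner=(16, 0, 32), patch_size=16)
print(inp.shape)  # (5, 16, 16, 16)
```

Stacking along a channel axis only works when the downsampled volume has the same spatial size as the patch, which is why this sketch pools the full volume to $p^3$; the actual model may resize or fuse the global context differently.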

Paper Structure

This paper contains 30 sections, 17 equations, 18 figures, 7 tables, 2 algorithms.

Figures (18)

  • Figure 1: Training the 3D patch diffusion model. The noisy 3D patch $\bm{G}_c\bm{x}_t$ (local details), the downsampled volume $\bm{D}\bm{x}_t$ (global context), and a positional encoding are concatenated as inputs to the denoiser network $D_{\theta}$ in each training iteration to predict the noise $\bm{\epsilon}_\theta$ in the input noisy 3D patch. The positional encoding consists of voxel-based $x,y,z$ coordinates normalized to $[-1,1]$.
  • Figure 2: Schematic illustration for zero padding and partitioning image into 3D patches. Each index $i$ represents one of $P^3$ possible ways to choose a patch offset tuple.
  • Figure 3: Unconditional 3D image generation results using the LIDC-IDRI prior. The top row shows axial, coronal, and sagittal slices from a generated volume, and the bottom row shows the corresponding slices from its nearest-neighbor volume in the training dataset. The slice indices for the axial, coronal, and sagittal views are $[30, 80, 130, 180, 230]$, $[70, 100, 130, 160, 190]$, and $[60, 130, 160, 190, 210]$, respectively.
  • Figure 4: Results of our proposed method and comparison methods for 20-view CT reconstruction on the LIDC $256 \times 256 \times 256$ dataset. Images are shown in modified Hounsfield units. The top row shows the axial slice and the bottom row shows the sagittal slice from the reconstructed volume.
  • Figure 5: Sagittal slices of generated 3D volumes for different numbers of sampling steps. Generated image quality degrades if the number of steps is reduced too much.
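The zero-padding and partitioning step in Figure 2 can be sketched as follows. This is a minimal NumPy illustration assuming that $P$ is the patch edge length in voxels, that every axis is padded with $P$ zeros on both sides, and that each offset entry ranges over $\{0,\dots,P-1\}$, which gives the $P^3$ offset tuples mentioned in the caption. The padding width and the helper names `partition_with_offset` and `reassemble` are hypothetical, not taken from the paper.

```python
import itertools

import numpy as np


def partition_with_offset(x, p, offset):
    """Hypothetical sketch: zero-pad a volume and split it into non-overlapping p^3 patches.

    x:      volume of shape (N0, N1, N2), each N_k a multiple of p (assumption)
    p:      patch edge length
    offset: (o0, o1, o2) with every o_k in {0, ..., p-1}; p^3 such tuples exist
    Returns blocks of shape (G0, G1, G2, p, p, p) with G_k = (N_k + p) / p.
    """
    if any(n % p for n in x.shape):
        raise ValueError("every axis length must be a multiple of p")
    if len(offset) != 3 or any(not 0 <= o < p for o in offset):
        raise ValueError("offset entries must lie in [0, p)")
    # Pad p zeros before and after every axis (assumed padding width).
    padded = np.pad(x, p, mode="constant")
    # Crop a window of length N_k + p starting at o_k; it always contains the
    # original voxels, which occupy [p, p + N_k) in padded coordinates.
    window = padded[tuple(slice(o, o + n + p) for o, n in zip(offset, x.shape))]
    g = [(n + p) // p for n in x.shape]
    blocks = window.reshape(g[0], p, g[1], p, g[2], p)
    return blocks.transpose(0, 2, 4, 1, 3, 5)


def reassemble(blocks, p, offset, shape):
    """Inverse of partition_with_offset: stitch the patches and drop the padding."""
    g0, g1, g2 = blocks.shape[:3]
    window = blocks.transpose(0, 3, 1, 4, 2, 5).reshape(g0 * p, g1 * p, g2 * p)
    # Inside the window, the original volume starts at index p - o_k on each axis.
    return window[tuple(slice(p - o, p - o + n) for o, n in zip(offset, shape))]


x = np.random.default_rng(0).standard_normal((8, 12, 4))
p = 4
offsets = list(itertools.product(range(p), repeat=3))
print(len(offsets))  # 64 = p**3 offset tuples
for off in offsets:
    assert np.array_equal(reassemble(partition_with_offset(x, p, off), p, off, x.shape), x)
```

Every offset tuple yields a non-overlapping tiling that still covers the whole volume, so changing the offset moves the patch boundaries without leaving any voxel uncovered; patches that fall entirely in the padding are simply zero.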