
Latent Diffusion for Medical Image Segmentation: End to end learning for fast sampling and accuracy

Fahim Ahmed Zaman, Mathews Jacob, Amanda Chang, Kan Liu, Milan Sonka, Xiaodong Wu

TL;DR

This work tackles the inefficiencies of diffusion probabilistic models (DPMs) in medical image segmentation by introducing LDSeg, an end-to-end conditional latent diffusion framework. LDSeg jointly learns latent representations of object shapes and source images, and trains a latent-space denoiser with segmentation-aware objectives, enabling fast posterior sampling and improved accuracy. Across Echo, GlaS, and Knee datasets, LDSeg achieves state-of-the-art segmentation performance while dramatically reducing inference time and memory usage, and it demonstrates robust performance under noisy input conditions. The approach provides practical benefits for high-dimensional medical imaging, including 2D and 3D data, with the ability to quantify segmentation uncertainty via multiple sampling runs.
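The TL;DR notes that segmentation uncertainty can be quantified by running the sampler multiple times. The sketch below shows one way to summarise such runs. It assumes binary masks and uses the per-pixel foreground frequency and its variance p(1 - p) as the uncertainty map. `toy_sampler` is a hypothetical stand-in for one stochastic LDSeg reverse-diffusion run, and the choice of statistic is an assumption, not the paper's stated procedure.

```python
import random
from typing import Callable, List, Tuple

Mask = List[int]  # flattened binary mask, 1 = foreground


def sample_uncertainty(
    sampler: Callable[[random.Random], Mask], n_runs: int, seed: int = 0
) -> Tuple[List[float], List[float], Mask]:
    """Run a stochastic segmentation sampler n_runs times.

    Returns the per-pixel foreground frequency p, its Bernoulli variance
    p * (1 - p) (high where runs disagree), and a majority-vote mask.
    """
    rng = random.Random(seed)
    runs = [sampler(rng) for _ in range(n_runs)]
    n_pixels = len(runs[0])
    freq = [sum(run[j] for run in runs) / n_runs for j in range(n_pixels)]
    variance = [p * (1.0 - p) for p in freq]
    consensus = [int(p >= 0.5) for p in freq]
    return freq, variance, consensus


def toy_sampler(rng: random.Random) -> Mask:
    # Hypothetical stand-in for one reverse-diffusion run: pixels 0 and 1 are
    # always foreground, pixel 2 is ambiguous, pixel 3 is always background.
    return [1, 1, int(rng.random() < 0.5), 0]


freq, variance, consensus = sample_uncertainty(toy_sampler, n_runs=100)
print(freq, variance, consensus)
```

Pixels on which every run agrees get zero variance. Ambiguous pixels approach 0.25, which is the maximum for a binary outcome.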

Abstract

Diffusion Probabilistic Models (DPMs) suffer from inefficient inference due to slow sampling and high memory consumption, which limits their applicability in many medical imaging applications. In this work, we propose LDSeg, a novel conditional diffusion framework for medical image segmentation. It leverages learned low-dimensional latent shape manifolds of the target objects together with embeddings of the source image, and all components are trained end to end. Conditional diffusion in latent space yields accurate segmentation of multiple interacting objects. It also addresses three fundamental issues of traditional DPM-based segmentation methods: (1) high memory consumption, (2) a time-consuming sampling process, and (3) unnatural noise injection in the forward and reverse processes. The end-to-end training strategy enables robust learning of segmentation-related latent representations, which makes sampling from the posterior distribution significantly faster during inference. Our experiments demonstrate that LDSeg achieves state-of-the-art segmentation accuracy on three medical image datasets with different imaging modalities. In addition, we show that the proposed model is significantly more robust to noise than traditional deterministic segmentation models. The code is available at https://github.com/FahimZaman/LDSeg.git.
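In LDSeg the forward process perturbs the continuous latent code $z_{l(0)}$ rather than the binary label map. This is one way to read point (3) above. The sketch below implements the standard DDPM closed-form forward step $z_{l(t)} = \sqrt{\bar\alpha_t}\, z_{l(0)} + \sqrt{1-\bar\alpha_t}\,\epsilon$ with $\bar\alpha_t = \prod_{s\le t}(1-\beta_s)$. The linear schedule from $10^{-4}$ to $0.02$ over $T=1000$ steps is borrowed from DDPM defaults and is an assumption, as is the toy latent vector.

```python
import math
import random
from typing import List


def linear_beta_schedule(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    # Assumed DDPM-style linear schedule; LDSeg's actual schedule may differ.
    return [beta_start + (beta_end - beta_start) * k / (T - 1) for k in range(T)]


def cumulative_alphas(betas: List[float]) -> List[float]:
    # abar[t - 1] = prod_{s=1..t} (1 - beta_s), for t = 1..T
    abar, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        abar.append(prod)
    return abar


def q_sample(z0: List[float], t: int, abar: List[float], rng: random.Random) -> List[float]:
    """Closed-form forward diffusion: z_t = sqrt(abar_t) * z0 + sqrt(1 - abar_t) * eps."""
    a = abar[t - 1]  # timesteps are 1-indexed (t = 1..T)
    return [math.sqrt(a) * z + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for z in z0]


T = 1000
abar = cumulative_alphas(linear_beta_schedule(T))
z_l0 = [0.8, -1.2, 0.3, 0.0]  # hypothetical latent code of a label map
rng = random.Random(0)
z_l1 = q_sample(z_l0, 1, abar, rng)  # barely perturbed
z_lT = q_sample(z_l0, T, abar, rng)  # almost pure Gaussian noise
print(f"abar_1 = {abar[0]:.4f}, abar_T = {abar[-1]:.2e}")
```

With this schedule, $\bar\alpha_T$ is roughly $4\times10^{-5}$. At that value, $z_{l(T)}$ carries almost no information about $z_{l(0)}$.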

Paper Structure

This paper contains 17 sections, 13 equations, 11 figures, 5 tables, and 2 algorithms.

Figures (11)

  • Figure 1: The proposed LDSeg model. The label encoder $f_\text{label-enc}$ and image encoder $f_\text{image-enc}$ produce the low-dimensional latent representations $z_{l(0)}$ and $z_i$ of a given ground-truth label/mask image $y$ and source image $X$, respectively. A denoiser $f_\text{denoiser}$, conditioned on the source image embedding $z_i$, learns the noise distributions of $z_{l(t)}$ for timesteps $t=1, \dotsc ,T$, where $T$ is the total number of diffusion steps. $z_{l(t)}$ is obtained by perturbing $z_{l(0)}$ with a Gaussian block $\mathcal{G}(\cdot)$ for given noise-variance schedules $\alpha$ and $\beta$. The denoised latent $z_{dn}$ is obtained by subtracting the predicted noise $z_{n(t)}$ from the perturbed latent $z_{l(t)}$. Finally, a label decoder $f_{\text{label-dec}}$ maps $z_{dn}$ to the segmentation $\hat{y}$ of the semantic labels in the original image. The model is trained end to end with the objective of learning $q(\hat{y}|X)=\mathbb{E}_{q_{i}(z_{i}|X)}\left[q_{s}(\hat{y}|z)\right]$, where $q_{l}(z \mid y, X) \sim \mathcal{N}(z_{dn}, \sigma^2 \mathrm{I})$. During inference, the process starts from random Gaussian noise $\tilde{z}_{l(T)} \sim \mathcal{N}(\mathrm{0,I})$. The denoiser is iterated over timesteps $t=T, \dotsc ,1$ with $z_i$ as the condition to obtain $\tilde{z}_{l(0)}$. The final segmentation $\hat{y}=f_{\text{label-dec}}(\tilde{z}_{l(0)})$ is obtained using the trained label decoder.
  • Figure 2: A sample from the GlaS dataset [sirinukunwattana2016gland] is used to demonstrate the forward and reverse diffusion processes. In the forward process (top row), the low-dimensional latent representation $z_{l(0)}$ is first obtained from the label image. Gaussian noise $\epsilon \sim \mathcal{N}(\mathrm{0,I})$ is then gradually injected over timesteps $t=1, \dotsc ,T$ according to the noise-variance schedule $\beta$. At timestep $T$, $z_{l(T)}$ approximately follows $\mathcal{N}(\mathrm{0,I})$. The reverse process (bottom row) starts from $\tilde{z}_{l(T)}$ sampled from $\mathcal{N}(\mathrm{0,I})$. The denoiser is then applied iteratively for timesteps $t=T, \dotsc, 1$, conditioned on the source image embedding $z_i$. At the end of the reverse process, the segmentation mask is obtained from $\tilde{z}_{l(0)}$ using the trained label decoder.
  • Figure 4: Qualitative segmentation results of different methods on the GlaS (top row) and Echo (bottom row) datasets. Dark red marks false-negative errors and light red marks false-positive errors. GT denotes the ground truth (label image).
  • Figure 5: Number of evenly spaced sampling steps vs. DSC for different datasets. The DDIM algorithm with only $2$ evenly spaced sampling steps between $1$ and $T=1000$ (inclusive) achieved the maximum segmentation accuracy on all datasets. The number of steps is plotted on a logarithmic scale.
  • Figure 6: (a) Number of sampling steps vs. DSC for the LDSeg, LSegDiff, and MedSegDiff models on the GlaS dataset. LDSeg achieved its maximum DSC with only $2$ sampling steps, compared with $10$ for LSegDiff and $700$ for MedSegDiff. (b) Image size vs. execution time for segmenting a single image with different DPMs. The execution times of LDSeg and LSegDiff (both latent diffusion models) remained nearly constant owing to their constrained low-dimensional latent spaces. In contrast, the execution times of SDF-DDPM and MedSegDiff increased exponentially with image size.
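Figures 1, 2, 5, and 6 describe inference as repeated application of the conditioned denoiser, starting from $\tilde{z}_{l(T)} \sim \mathcal{N}(\mathrm{0,I})$. Figure 5 reports that DDIM with only 2 evenly spaced steps between 1 and $T=1000$ (inclusive) gave the best DSC. The sketch below runs deterministic DDIM ($\eta = 0$) over such an evenly spaced subsequence. It assumes an ε-predicting denoiser and the DDPM linear β schedule.

Figure 1's caption describes $z_{dn}$ as the perturbed latent minus the predicted noise. The sketch instead uses the standard ε-parameterised estimate, which additionally rescales by $\sqrt{\bar\alpha_t}$; whether LDSeg does the same is not stated here. `ddim_sample`, `make_oracle_denoiser`, and the toy vectors are hypothetical. The oracle stands in for $f_\text{denoiser}$ only to make the arithmetic checkable.

```python
import math
import random
from typing import Callable, List, Sequence

Vec = List[float]
# Hypothetical denoiser signature: (z_t, t, z_i) -> predicted noise z_n(t)
Denoiser = Callable[[Vec, int, Vec], Vec]


def evenly_spaced_timesteps(T: int, n_steps: int) -> List[int]:
    """n_steps timesteps from T down to 1, both ends included (n_steps=2 -> [T, 1])."""
    if n_steps == 1:
        return [T]
    return [round(T - k * (T - 1) / (n_steps - 1)) for k in range(n_steps)]


def ddim_sample(denoiser: Denoiser, z_i: Vec, abar: Sequence[float],
                n_steps: int, dim: int, rng: random.Random) -> Vec:
    """Deterministic DDIM (eta = 0) reverse process in latent space."""
    ts = evenly_spaced_timesteps(len(abar), n_steps)
    z = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # z~_l(T) ~ N(0, I)
    for i, t in enumerate(ts):
        a_t = abar[t - 1]
        a_prev = abar[ts[i + 1] - 1] if i + 1 < len(ts) else 1.0  # abar_0 := 1
        eps = denoiser(z, t, z_i)
        # Epsilon-parameterised estimate of the clean latent (the role of z_dn).
        z0_hat = [(zj - math.sqrt(1.0 - a_t) * ej) / math.sqrt(a_t) for zj, ej in zip(z, eps)]
        # Jump directly to the next (earlier) timestep in the subsequence.
        z = [math.sqrt(a_prev) * x0 + math.sqrt(1.0 - a_prev) * ej
             for x0, ej in zip(z0_hat, eps)]
    return z  # z~_l(0); a label decoder would map this to the mask y_hat


def make_oracle_denoiser(z0: Vec, abar: Sequence[float]) -> Denoiser:
    # Hypothetical "perfect" denoiser for illustration: returns the exact noise
    # separating z_t from a known clean latent z0 and ignores the condition z_i.
    def denoiser(z_t: Vec, t: int, z_i: Vec) -> Vec:
        a = abar[t - 1]
        return [(zt - math.sqrt(a) * x0) / math.sqrt(1.0 - a) for zt, x0 in zip(z_t, z0)]
    return denoiser


# Assumed DDPM linear beta schedule over T = 1000 steps.
T = 1000
abar, prod = [], 1.0
for k in range(T):
    prod *= 1.0 - (1e-4 + (0.02 - 1e-4) * k / (T - 1))
    abar.append(prod)

z_true = [0.8, -1.2, 0.3, 0.0]  # hypothetical clean label latent
z_i = [0.1, 0.2]                # hypothetical image embedding (unused by the oracle)
z_hat = ddim_sample(make_oracle_denoiser(z_true, abar), z_i, abar,
                    n_steps=2, dim=len(z_true), rng=random.Random(0))
print(evenly_spaced_timesteps(T, 2), z_hat)
```

With `n_steps=2`, the subsequence is `[1000, 1]`. Because the oracle's noise prediction is exact, the two steps land on `z_true` up to floating-point error. A trained denoiser is not exact, so this illustrates the mechanics of few-step sampling rather than LDSeg's accuracy.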