Conditional diffusion model with spatial attention and latent embedding for medical image segmentation
Behzad Hejrati, Soumyanil Banerjee, Carri Glide-Hurst, Ming Dong
TL;DR
This work targets the slow inference of diffusion-based medical image segmentation by introducing cDAL, a conditional diffusion model that uses a discriminator at every diffusion step to derive spatial attention, together with a latent embedding that promotes multimodality. The method conditions label generation on the input image $I$, uses an attention map $A_D$ derived from discriminator features to focus on discriminative regions, and injects a latent variable $z$ that reduces the required number of diffusion steps to $T \le 4$, enabling fast training and sampling. Across the MoNuSeg, Chest X-ray, and Hippocampus datasets, cDAL achieves higher Dice scores and mIoU than state-of-the-art methods, with substantially shorter inference times (about 1 s) than diffusion-based baselines that use many steps. The contributions are per-time-step discriminators that guide attention, a latent embedding for multimodal denoising, and a practical, faster diffusion-based approach to clinical segmentation; the code is publicly available.
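The text does not specify how $A_D$ is computed from discriminator features. As a minimal sketch, assuming the map is the channel-wise mean of absolute discriminator activations, min-max normalized to [0, 1], nearest-upsampled to the generator's resolution, and applied as a residual gate $f \cdot (1 + A_D)$, the idea looks like this in NumPy. The function names and all of these design choices are hypothetical illustrations, not the authors' implementation.

```python
import numpy as np


def spatial_attention_from_features(d_feats: np.ndarray) -> np.ndarray:
    """Collapse discriminator features of shape (C, H, W) into an (H, W) map in [0, 1].

    Hypothetical choice: channel-wise mean of absolute activations,
    followed by min-max normalization. The paper may define A_D differently.
    """
    energy = np.abs(d_feats).mean(axis=0)  # (H, W) activation magnitude
    lo, hi = energy.min(), energy.max()
    if hi - lo < 1e-12:
        # Constant map: no region is more discriminative, so attend uniformly.
        return np.ones_like(energy)
    return (energy - lo) / (hi - lo)


def upsample_nearest(attn: np.ndarray, factor: int) -> np.ndarray:
    """Nearest-neighbour upsampling by an integer factor (discriminators usually downsample)."""
    return np.repeat(np.repeat(attn, factor, axis=0), factor, axis=1)


def apply_attention(g_feats: np.ndarray, attn: np.ndarray) -> np.ndarray:
    """Reweight generator features (C, H, W) with a residual gate f * (1 + A_D).

    The residual form (an assumption) keeps every feature and only amplifies
    regions the discriminator found informative.
    """
    if g_feats.shape[1:] != attn.shape:
        raise ValueError(f"spatial size mismatch: {g_feats.shape[1:]} vs {attn.shape}")
    return g_feats * (1.0 + attn[None, :, :])


# Usage with random stand-in features (shapes are illustrative only).
rng = np.random.default_rng(0)
d_feats = rng.standard_normal((64, 8, 8))    # discriminator features at half resolution
g_feats = rng.standard_normal((32, 16, 16))  # generator features
A_D = upsample_nearest(spatial_attention_from_features(d_feats), factor=2)
out = apply_attention(g_feats, A_D)
print(A_D.shape, out.shape)  # (16, 16) (32, 16, 16)
```

The residual gate is one common way to use an attention map without suppressing any features outright; a plain multiplicative gate $f \cdot A_D$ is the other obvious option, at the cost of zeroing features wherever the map is zero.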
Abstract
Diffusion models have been used extensively for high-quality image and video generation. In this paper, we propose a novel conditional diffusion model with spatial attention and latent embedding (cDAL) for medical image segmentation. In cDAL, a convolutional neural network (CNN) based discriminator is used at every time-step of the diffusion process to distinguish between generated labels and real ones. A spatial attention map is computed from the features learned by the discriminator to help cDAL segment discriminative regions of an input image more accurately. Additionally, we incorporate a random latent embedding into each layer of our model to significantly reduce the number of training and sampling time-steps, making cDAL much faster than other diffusion models for image segmentation. We applied cDAL to three publicly available medical image segmentation datasets (MoNuSeg, Chest X-ray and Hippocampus) and observed significant qualitative and quantitative improvements, with higher Dice scores and mIoU than state-of-the-art algorithms. The source code is publicly available at https://github.com/Hejrati/cDAL/.
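The reported metrics follow standard definitions: for a predicted mask $P$ and a ground-truth mask $G$, $\text{Dice} = 2|P \cap G| / (|P| + |G|)$ and $\text{IoU} = |P \cap G| / |P \cup G|$. The abstract does not say how mIoU is averaged, so the sketch below assumes a binary task and averages IoU over the background and foreground classes; the function names and the small `eps` smoothing term are illustrative choices, not taken from the paper.

```python
import numpy as np


def _as_bool(mask) -> np.ndarray:
    """Binarize a mask: any nonzero value counts as foreground."""
    return np.asarray(mask) != 0


def dice_score(pred, target, eps: float = 1e-7) -> float:
    """Dice = 2|P & G| / (|P| + |G|); eps makes two empty masks score 1.0."""
    p, g = _as_bool(pred), _as_bool(target)
    inter = np.logical_and(p, g).sum()
    return float((2.0 * inter + eps) / (p.sum() + g.sum() + eps))


def iou_score(pred, target, eps: float = 1e-7) -> float:
    """IoU (Jaccard) = |P & G| / |P | G|."""
    p, g = _as_bool(pred), _as_bool(target)
    inter = np.logical_and(p, g).sum()
    union = np.logical_or(p, g).sum()
    return float((inter + eps) / (union + eps))


def mean_iou_binary(pred, target) -> float:
    """mIoU over {background, foreground}; this averaging scheme is an assumption."""
    p, g = _as_bool(pred), _as_bool(target)
    return 0.5 * (iou_score(p, g) + iou_score(~p, ~g))


# Worked example on a 3x3 mask: 3 predicted and 3 true foreground pixels, 2 shared.
pred = np.array([[1, 1, 0],
                 [0, 1, 0],
                 [0, 0, 0]])
target = np.array([[1, 0, 0],
                   [0, 1, 1],
                   [0, 0, 0]])
print(round(dice_score(pred, target), 4))       # 0.6667  (2*2 / (3+3))
print(round(iou_score(pred, target), 4))        # 0.5     (2 / 4)
print(round(mean_iou_binary(pred, target), 4))  # 0.6071  ((2/4 + 5/7) / 2)
```

For a single mask, Dice $= 2\,\text{IoU}/(1+\text{IoU})$, so the two metrics rank individual predictions identically; they can diverge once averaged over images or classes.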
