
Latent Drifting in Diffusion Models for Counterfactual Medical Image Synthesis

Yousef Yeganeh, Azade Farshad, Ioannis Charisiadis, Marta Hasny, Martin Hartenberger, Björn Ommer, Nassir Navab, Ehsan Adeli

TL;DR

Diffusion models for medical imaging suffer from data scarcity and from distribution shift when transferred from natural to medical imagery. Latent Drifting (LD) introduces a latent-space drift $\delta$ into the diffusion process to align the pre-trained distribution $\mathcal{D}_\theta$ with a target medical distribution $\mathcal{D}_{GT}$ via a counterfactual objective, and it can be combined with any fine-tuning method. The approach formalizes conditioning through a min–max formulation and demonstrates robust improvements in counterfactual medical image generation and manipulation on longitudinal brain MRI and CheXpert datasets. This yields high-fidelity, clinically relevant counterfactuals while preserving data privacy and reducing the need for large medical training sets, with potential impact on studies of prognosis, aging, and disease modification.
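The sketch below illustrates what a latent-space drift could look like. It uses the standard DDPM forward noising step and adds a scalar drift $\delta$ to the Gaussian noise. This is an assumption made for illustration: the summary does not specify where $\delta$ enters the diffusion process. The helpers `linear_alpha_bar` and `noisy_latent_with_drift` and the linear β schedule are hypothetical choices, not the authors' implementation.

```python
import numpy as np


def linear_alpha_bar(num_steps: int = 1000,
                     beta_start: float = 1e-4,
                     beta_end: float = 0.02) -> np.ndarray:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule (DDPM-style)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)


def noisy_latent_with_drift(z0: np.ndarray, t: int, alpha_bar: np.ndarray,
                            delta: float, rng: np.random.Generator):
    """Hypothetical drifted forward step (an assumption, not the paper's code):

        z_t = sqrt(alpha_bar_t) * z0 + sqrt(1 - alpha_bar_t) * (eps + delta),  eps ~ N(0, I)

    With delta = 0 this reduces to the standard DDPM forward process.
    Returns the noisy latent and the drifted noise (eps + delta).
    """
    eps = rng.standard_normal(z0.shape)
    drifted_noise = eps + delta          # shifts the noise mean from 0 to delta
    a_t = alpha_bar[t]
    z_t = np.sqrt(a_t) * z0 + np.sqrt(1.0 - a_t) * drifted_noise
    return z_t, drifted_noise


# Usage: same seed with and without drift, on an illustrative 4x64x64 latent.
z0 = np.zeros((4, 64, 64))
alpha_bar = linear_alpha_bar()
z_plain, _ = noisy_latent_with_drift(z0, 500, alpha_bar, 0.0, np.random.default_rng(0))
z_drift, _ = noisy_latent_with_drift(z0, 500, alpha_bar, 0.05, np.random.default_rng(0))
# The drift shifts every element by sqrt(1 - alpha_bar_t) * delta.
print(np.allclose(z_drift - z_plain, np.sqrt(1 - alpha_bar[500]) * 0.05))  # prints True
```

Adding $\delta$ to the noise moves its mean away from zero while keeping unit variance. Under this assumed parameterization, the summary does not settle one design choice: whether a denoiser should then be trained to predict the plain or the drifted noise.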

Abstract

Scaling by training on large datasets has been shown to enhance the quality and fidelity of image generation and manipulation with diffusion models. However, such large datasets are not always accessible in medical imaging due to cost and privacy issues. This undermines one of the main applications of such models: producing synthetic samples where real data is scarce. Fine-tuning pre-trained general-purpose models is also challenging because of the distribution shift between the medical domain and the data the models were pre-trained on. Here, we propose Latent Drifting (LD) for diffusion models. LD can be adopted with any fine-tuning method to mitigate the issues caused by the distribution shift, or employed at inference time as a condition. Latent Drifting enables diffusion models to be conditioned on medical images for the complex task of counterfactual image generation. This task is crucial for investigating how parameters such as gender, age, and the addition or removal of diseases would alter a patient's medical images. We evaluate our method on three public longitudinal benchmark datasets of brain MRI and chest X-rays for counterfactual image generation. Our results demonstrate significant performance gains in various scenarios when combined with different fine-tuning schemes.

Paper Structure (14 sections, 5 equations, 8 figures, 4 tables)

Figures (8)

  • Figure 1: Medical Image Generation and Manipulation using LD. Left to right: (1) (text, image)-to-image without ground-truth (GT) pairs, showing disease removal and addition; (2) Aging: (text, image)-to-image with GT pairs; (3) text-to-image without GT pairs.
  • Figure 2: Samples generated with identical sampled noise and varying latent drift ($\delta \in [-0.1,0.1]$) at inference, with different text prompts using the pre-trained Stable Diffusion rombach2022high. Top: Elon Musk on a mountain, Bottom: Barack Obama on a plane.
  • Figure 3: Image and Latent Space Distribution w. and w/o. LD in Fine-tuning + Sampling. Rows 1-2: sampled images with different latent drift parameters ($\delta$) during inference. Row 3: channel-wise distribution change in images during the reverse sampling process. Row 4: distribution of the latent space $z_0$ during reverse sampling.
  • Figure 4: MRI Slice Generation for Cognitively Normal (CN) and Alzheimer's Disease (AD) after fine-tuning Stable Diffusion with LD using different methods.
  • Figure 5: Image Generation w. and w/o. LD during fine-tuning. Examples generated from left to right using Textual Inversion gal2022image, DreamBooth ruiz2023dreambooth, Custom Diffusion kumari2023multi, and Basic FT rombach2022high.
  • ...and 3 more figures
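Figure 2 keeps the sampled noise fixed and varies only the latent drift $\delta$ at inference. The sketch below reproduces that setup under one assumption: the drift is a scalar shift applied to the initial latent $z_T \sim \mathcal{N}(0, I)$. The summary does not state this, and the helper names are hypothetical. The sketch draws one noise tensor from a fixed seed, builds one drifted copy per $\delta$, and reports per-channel statistics, loosely mirroring the channel-wise distributions plotted in Figure 3. The 4×64×64 shape matches Stable Diffusion v1's latent for 512×512 images.

```python
import numpy as np


def drifted_initial_latents(shape, deltas, seed=0):
    """Hypothetical helper: {delta: z_T + delta}, all built from ONE shared noise draw."""
    base = np.random.default_rng(seed).standard_normal(shape)  # identical sampled noise
    return {float(d): base + d for d in deltas}


def channel_stats(z):
    """Per-channel mean and standard deviation of a (C, H, W) latent."""
    return z.mean(axis=(1, 2)), z.std(axis=(1, 2))


deltas = np.linspace(-0.1, 0.1, 5)  # sweep range taken from Figure 2
latents = drifted_initial_latents((4, 64, 64), deltas, seed=42)
for d, z in latents.items():
    mean, std = channel_stats(z)
    print(f"delta={d:+.2f}  channel means={np.round(mean, 3)}  channel stds={np.round(std, 3)}")
```

Under this assumption, each channel mean moves by exactly $\delta$ while the standard deviation stays the same. This is the simplest possible notion of drifting the latent distribution. In a real pipeline, each drifted latent would be passed to the same sampler and prompt, and the paper's actual parameterization may differ.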