Cycle Diffusion Model for Counterfactual Image Generation
Fangrui Huang, Alan Wang, Binxu Li, Bailey Trang, Ridvan Yesiloglu, Tianyu Hua, Wei Peng, Ehsan Adeli
TL;DR
This work addresses conditioning-faithful medical image synthesis by introducing Cycle Diffusion Model (CDM), which couples a cycle-consistent training objective with diffusion-based generation to produce accurate direct and counterfactual 3D brain MRIs conditioned on age and sex. CDM extends latent diffusion models with bidirectional (counterfactual and factual) generation and a cycle-regularization term, optimized via a composite loss that includes standard LDM denoising terms and a cycle-consistency penalty. Evaluated on a large, multi-study 3D brain MRI dataset, CDM outperforms baselines in conditioning accuracy (lower age MAE, higher sex accuracy), image quality (FID), and diversity (MS-SSIM), while also producing anatomically realistic age-related changes in counterfactuals. The approach offers potential for targeted data augmentation and disease progression modeling, though it incurs higher inference cost and requires careful consideration of distribution shifts and clinical validation for deployment.
Abstract
Deep generative models have demonstrated remarkable success in medical image synthesis. However, ensuring that synthetic images are both faithful to their conditioning and of high quality, whether for direct or counterfactual generation, remains a challenge. In this work, we introduce a cycle training framework to fine-tune diffusion models for improved conditioning adherence and enhanced synthetic image realism. Our approach, Cycle Diffusion Model (CDM), enforces consistency between generated and original images by incorporating cycle constraints, enabling more reliable direct and counterfactual generation. Experiments on a combined 3D brain MRI dataset (from ABCD, HCP aging & young adults, ADNI, and PPMI) show that our method improves conditioning accuracy and enhances image quality as measured by FID and SSIM. The results suggest that the cycle strategy used in CDM can be an effective method for refining diffusion-based medical image generation, with applications in data augmentation, counterfactual generation, and disease progression modeling.
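The cycle constraint described above can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the generators here are hypothetical toy functions rather than latent diffusion models, and the function names, the L1 cycle distance, and the weight `lam` are all assumptions chosen for illustration. The sketch only shows the general shape of a composite objective that adds a cycle-consistency penalty (factual to counterfactual and back) to a denoising term.

```python
import numpy as np

def invertible_toy_generator(x, src_age, tgt_age):
    """Hypothetical stand-in for a conditional generator (not a diffusion
    model): shifts intensities in proportion to the age difference."""
    return x + 0.01 * (tgt_age - src_age)

def lossy_toy_generator(x, src_age, tgt_age):
    """Hypothetical generator that also shrinks the image, so a round trip
    does not recover the input."""
    return 0.9 * x + 0.01 * (tgt_age - src_age)

def cycle_loss(x, age, cf_age, generator):
    """Mean absolute error between x and its reconstruction after a
    factual -> counterfactual -> factual round trip (L1 is an assumption)."""
    x_cf = generator(x, age, cf_age)      # counterfactual image at cf_age
    x_rec = generator(x_cf, cf_age, age)  # mapped back to the original age
    return float(np.mean(np.abs(x - x_rec)))

def composite_loss(denoise_loss, cyc_loss, lam=1.0):
    """Denoising term plus a weighted cycle penalty; lam is an assumed weight."""
    return denoise_loss + lam * cyc_loss

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 4, 4))  # toy 3D "volume"

# Pretend noise-prediction error for the denoising term (purely illustrative).
eps = rng.normal(size=x.shape)
eps_pred = eps + 0.1
denoise = float(np.mean((eps - eps_pred) ** 2))

cyc_good = cycle_loss(x, 30.0, 70.0, invertible_toy_generator)
cyc_bad = cycle_loss(x, 30.0, 70.0, lossy_toy_generator)

total_good = composite_loss(denoise, cyc_good, lam=0.5)
total_bad = composite_loss(denoise, cyc_bad, lam=0.5)
print(total_good, total_bad)
```

In this toy setting, the generator whose round trip recovers the input incurs essentially no cycle penalty, while the lossy generator is penalized; that is the intuition behind using a cycle term to discourage edits that alter content unrelated to the changed condition.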
