
Pretraining Transformer-Based Models on Diffusion-Generated Synthetic Graphs for Alzheimer's Disease Prediction

Abolfazl Moslemi, Hossein Peyvandi

TL;DR

This work tackles early Alzheimer’s disease (AD) prediction under limited labeled data and multi-site heterogeneity by proposing a diffusion-based synthetic pretraining framework combined with modality-specific Graph Transformer encoders. A class-conditional DDPM is trained on real NACC data to model $p(x \mid y)$, the distribution of features $x$ given diagnostic label $y$, and to generate a large, balanced synthetic cohort. This cohort pretrains per-modality Graph Transformers, which are then frozen while a downstream classifier is trained on real data. The study probes distributional alignment via maximum mean discrepancy (MMD), Fréchet distance, and energy distance, and adds calibration and decision-curve analyses to assess clinical utility. Empirically, the approach achieves higher discrimination (e.g., an AUC of 0.914) than strong baselines such as Early/Late Fusion DNNs and MaGNet, demonstrating improved generalization in low-sample, imbalanced clinical settings and suggesting practical value for multimodal AD prediction.
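The three alignment metrics named in the TL;DR have standard closed forms. The following is a minimal NumPy/SciPy sketch of one common way to compute them, assuming the real and synthetic cohorts are numeric feature matrices with matching columns (one row per subject, already preprocessed). The RBF bandwidth `gamma`, the biased (V-statistic) estimators, the Gaussian fit used for the Fréchet distance, and all helper names are illustrative assumptions, not the authors' reported settings.

```python
import numpy as np
from scipy.linalg import sqrtm


def _pairwise_sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between every row of a and every row of b."""
    d = (a * a).sum(1)[:, None] + (b * b).sum(1)[None, :] - 2.0 * a @ b.T
    return np.maximum(d, 0.0)  # clip tiny negatives caused by round-off


def mmd_rbf(x: np.ndarray, y: np.ndarray, gamma: float = 1.0) -> float:
    """Biased estimate of squared MMD with RBF kernel k(u, v) = exp(-gamma * ||u - v||^2)."""
    kxx = np.exp(-gamma * _pairwise_sq_dists(x, x)).mean()
    kyy = np.exp(-gamma * _pairwise_sq_dists(y, y)).mean()
    kxy = np.exp(-gamma * _pairwise_sq_dists(x, y)).mean()
    return float(kxx + kyy - 2.0 * kxy)


def frechet_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to x and y:
    ||mu_x - mu_y||^2 + Tr(S_x + S_y - 2 (S_x S_y)^{1/2})."""
    mu_x, mu_y = x.mean(0), y.mean(0)
    # atleast_2d keeps the single-feature case a 1x1 matrix instead of a scalar
    cov_x = np.atleast_2d(np.cov(x, rowvar=False))
    cov_y = np.atleast_2d(np.cov(y, rowvar=False))
    covmean = np.real(sqrtm(cov_x @ cov_y))  # drop negligible imaginary parts
    diff = mu_x - mu_y
    return float(diff @ diff + np.trace(cov_x + cov_y - 2.0 * covmean))


def energy_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Multivariate energy distance 2 E||X - Y|| - E||X - X'|| - E||Y - Y'|| (V-statistic)."""
    dxy = np.sqrt(_pairwise_sq_dists(x, y)).mean()
    dxx = np.sqrt(_pairwise_sq_dists(x, x)).mean()
    dyy = np.sqrt(_pairwise_sq_dists(y, y)).mean()
    return float(2.0 * dxy - dxx - dyy)


if __name__ == "__main__":
    # Toy stand-ins for a real cohort and two synthetic cohorts (hypothetical data).
    rng = np.random.default_rng(0)
    real = rng.normal(size=(200, 4))
    synth_good = rng.normal(size=(200, 4))
    synth_bad = rng.normal(loc=2.0, size=(200, 4))
    for name, s in [("good", synth_good), ("bad", synth_bad)]:
        print(name, mmd_rbf(real, s, gamma=0.25),
              frechet_distance(real, s), energy_distance(real, s))
```

For all three quantities, smaller values indicate closer alignment. MMD depends noticeably on `gamma`, so a median-pairwise-distance heuristic is a common alternative to a fixed bandwidth.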

Abstract

Early and accurate detection of Alzheimer's disease (AD) is crucial for enabling timely intervention and improving outcomes. However, developing reliable machine learning (ML) models for AD diagnosis is challenging due to limited labeled data, multi-site heterogeneity, and class imbalance. We propose a Transformer-based diagnostic framework that combines diffusion-based synthetic data generation with graph representation learning and transfer learning. A class-conditional denoising diffusion probabilistic model (DDPM) is trained on the real-world National Alzheimer's Coordinating Center (NACC) dataset to generate a large synthetic cohort that mirrors multimodal clinical and neuroimaging feature distributions while balancing diagnostic classes. Modality-specific Graph Transformer encoders are first pretrained on this synthetic data to learn robust, class-discriminative representations and are then frozen while a neural classifier is trained on embeddings from the original NACC data. We quantify distributional alignment between real and synthetic cohorts using metrics such as Maximum Mean Discrepancy (MMD), Fréchet distance, and energy distance, and complement discrimination metrics with calibration and fixed-specificity sensitivity analyses. Empirically, our framework outperforms standard baselines, including early and late fusion deep neural networks and the multimodal graph-based model MaGNet, yielding higher AUC, accuracy, sensitivity, and specificity under subject-wise cross-validation on NACC. These results show that diffusion-based synthetic pretraining with Graph Transformers can improve generalization in low-sample, imbalanced clinical prediction settings.
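The abstract pairs discrimination metrics with calibration and fixed-specificity sensitivity analyses. Here is a minimal NumPy sketch of how such quantities are commonly computed, assuming a binary task (AD vs. non-AD) with predicted positive-class probabilities and both classes present. This section does not state the class definitions, the target specificity, or the binning scheme, so `target_spec=0.9`, `n_bins=10`, and all function names are hypothetical. The expected calibration error (ECE) here uses equal-width bins and compares the observed positive rate with the mean predicted probability in each bin.

```python
import numpy as np


def sensitivity_at_specificity(y_true, scores, target_spec=0.9):
    """Best sensitivity over thresholds whose specificity is >= target_spec.

    A subject is predicted positive when score >= threshold.
    Both classes are assumed to be present in y_true.
    """
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    best = 0.0
    # Candidate thresholds: every observed score, plus +inf (everyone negative).
    for thr in np.append(np.unique(scores), np.inf):
        pred = scores >= thr
        spec = np.mean(~pred[~y_true])  # true-negative rate
        sens = np.mean(pred[y_true])    # true-positive rate
        if spec >= target_spec:
            best = max(best, float(sens))
    return best


def expected_calibration_error(y_true, probs, n_bins=10):
    """Binary ECE: sum over bins of (n_b / n) * |observed positive rate - mean prob|."""
    y_true = np.asarray(y_true, dtype=float)
    probs = np.asarray(probs, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Map each probability to an equal-width bin in [0, n_bins - 1];
    # a probability of exactly 1.0 falls into the last bin.
    idx = np.clip(np.digitize(probs, edges[1:-1]), 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = idx == b
        if mask.any():
            ece += mask.mean() * abs(y_true[mask].mean() - probs[mask].mean())
    return float(ece)


def brier_score(y_true, probs):
    """Mean squared error between predicted probabilities and 0/1 labels."""
    y_true = np.asarray(y_true, dtype=float)
    probs = np.asarray(probs, dtype=float)
    return float(np.mean((probs - y_true) ** 2))


if __name__ == "__main__":
    # Hypothetical labels and scores, for illustration only.
    y = [0, 0, 0, 0, 1, 1, 1, 1]
    p = [0.1, 0.2, 0.3, 0.8, 0.4, 0.6, 0.7, 0.9]
    print(sensitivity_at_specificity(y, p, target_spec=0.75))
    print(expected_calibration_error(y, p), brier_score(y, p))
```

Reporting sensitivity at a fixed specificity pins down the false-positive rate the model is allowed to incur, which is often easier to interpret clinically than a threshold-free AUC; ECE and the Brier score then indicate whether the predicted probabilities themselves can be taken at face value.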


Paper Structure

This paper contains 27 sections, 7 equations, 1 table.