Guided Discrete Diffusion for Electronic Health Record Generation

Jun Han, Zixiang Chen, Yongqian Li, Yiwen Kou, Eran Halperin, Robert E. Tillman, Quanquan Gu

TL;DR

EHR-D3PM, a novel tabular EHR generation method built on a discrete diffusion model, supports both unconditional and conditional generation and significantly outperforms existing generative baselines on comprehensive fidelity and utility metrics while incurring lower attribute and membership inference risks.

Abstract

Electronic health records (EHRs) are a pivotal data source that enables numerous applications in computational medicine, e.g., disease progression prediction, clinical trial design, and health economics and outcomes research. Despite their wide usability, their sensitive nature raises privacy and confidentiality concerns, which limit potential use cases. To tackle these challenges, we explore the use of generative models to synthesize artificial yet realistic EHRs. While diffusion-based methods have recently demonstrated state-of-the-art performance in generating other data modalities and overcome the training instability and mode collapse issues that plague previous GAN-based approaches, their application to EHR generation remains underexplored. The discrete nature of tabular medical code data in EHRs poses challenges for high-quality data generation, especially for continuous diffusion models. To this end, we introduce a novel tabular EHR generation method, EHR-D3PM, which enables both unconditional and conditional generation using a discrete diffusion model. Our experiments demonstrate that EHR-D3PM significantly outperforms existing generative baselines on comprehensive fidelity and utility metrics while incurring lower attribute and membership inference risks. Furthermore, we show that EHR-D3PM is effective as a data augmentation method and improves performance on downstream tasks when combined with real data.
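
The abstract contrasts discrete diffusion with continuous diffusion for categorical medical codes. As a minimal sketch, assuming the uniform (multinomial) transition kernel from the D3PM family (this summary does not state which kernel EHR-D3PM actually uses), the forward process corrupts each code directly in category space instead of adding Gaussian noise. The function names and the noise schedule below are hypothetical.

```python
import numpy as np

def uniform_transition(K: int, beta: float) -> np.ndarray:
    """One-step kernel Q = (1 - beta) * I + (beta / K) * 11^T; every row sums to 1."""
    return (1.0 - beta) * np.eye(K) + (beta / K) * np.ones((K, K))

def cumulative_transition(K: int, betas) -> np.ndarray:
    """Qbar_t = Q_1 Q_2 ... Q_t, so q(x_t | x_0) = Cat(onehot(x_0) @ Qbar_t)."""
    Qbar = np.eye(K)
    for beta in betas:
        Qbar = Qbar @ uniform_transition(K, beta)
    return Qbar

def q_sample(x0: np.ndarray, Qbar: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Sample x_t ~ q(x_t | x_0) for integer codes x0 of shape (B, N) with values in [0, K)."""
    probs = Qbar[x0]                                  # (B, N, K): row x0 of Qbar per entry
    u = rng.random(probs.shape[:-1] + (1,))           # one uniform draw per entry
    idx = (probs.cumsum(axis=-1) < u).sum(axis=-1)    # inverse-CDF categorical sampling
    return np.minimum(idx, Qbar.shape[0] - 1)         # guard against float round-off

# Usage: binary EHR codes (K = 2: absent/present) with a hypothetical linear schedule.
rng = np.random.default_rng(0)
betas = np.linspace(1e-3, 0.1, 100)
Qbar_T = cumulative_transition(2, betas)
x0 = rng.integers(0, 2, size=(4, 10))
xT = q_sample(x0, Qbar_T, rng)
```

As t grows, each row of Qbar_t approaches the uniform distribution 1/K, which is the prior that the reverse (denoising) process starts from under this kernel.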

Paper Structure

This paper contains 35 sections, 18 equations, 9 figures, 9 tables.

Figures (9)

  • Figure 1: Comparison of code prevalence in synthetic data and real data $\mathcal{D}_2$ with ICD, CPT and GEN codes (total dimension 2683). The second row zooms in on the low-prevalence regime of the first row. Prevalence is computed on 200K samples. The dashed diagonal lines indicate a perfect match of code prevalence between synthetic and real EHR data. Pearson correlations are very high for all methods and are therefore not used as a metric to compare them.
  • Figure 2: Density comparison of the number of features per record in synthetic data and real data $\mathcal{D}_2$ with ICD, CPT and GEN codes. The number of features per record is the number of ICD codes present in each sample. The histogram uses 175 bins over the feature-count range (0, 175).
  • Figure 3: Synthetic data augmentation for disease classification from ICD codes on dataset $\mathcal{D}_2$. The real source data used to train the LGBM classifier contains 5000 samples, as indicated by the dashed purple line. We augment the source training data with synthetic data to train the LGBM classifier. "Uncond Samples" denotes synthetic data generated by our unconditional sampler; guided samples are synthetic data generated by our proposed guided sampler for each disease. To minimize evaluation noise, we evaluate all experiments on 200K real test samples and report test AUROC. We bootstrap 80% of the test data 50 times to compute 95% confidence intervals, shown as the shaded region around each line.
  • Figure 4: Architecture of our denoising model. Panel (b) details the transformer block, which has linear complexity with respect to the input dimension. Axial positional embedding is employed to encode positional information. We apply sinusoidal positional embedding to the time step $t$ to obtain the time embedding, then use a two-layer MLP, with a Softplus activation in its first layer, to map the time embedding into a hidden state. As indicated in (a), we apply L such two-layer MLPs to produce the time input of each transformer block. Positional embedding is added to the embedding of the discrete inputs. The input has dimension N, and B denotes the batch size. For notational simplicity, we assume that every dimension of the tabular data has K categories. We use a one-hot representation; therefore, the output of the denoising model has shape (B, N, K). The shapes of the intermediate layers are given in (a). In (b), "Proj" denotes the projection operation proposed in Linformer wang2020linformer, which makes the attention module linear in the input dimension $N$. The projection dimension is set to its default value of 128 in all experiments in this paper.
  • Figure 5: Comparison of code prevalence in synthetic data and real data (MIMIC). The second row zooms in on the low-prevalence regime of the first row. Prevalence is computed on 10K samples because the MIMIC dataset is relatively small. The dashed diagonal lines indicate a perfect match of code prevalence between synthetic and real EHR data. Pearson correlations are very high for all methods and are therefore not used as a metric to compare them.
  • ...and 4 more figures
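
The prevalence plots and the per-record feature-count densities rest on two simple statistics of a binary record-by-code matrix. The sketch below assumes records are encoded as a 0/1 NumPy array of shape (num_records, num_codes); the helper names are hypothetical and this is not the authors' evaluation script. Since the per-record count in Figure 2 covers ICD codes only, you would pass only the ICD columns to reproduce that panel.

```python
import numpy as np

def code_prevalence(records: np.ndarray) -> np.ndarray:
    """Fraction of records in which each code is present (column means of a 0/1 matrix)."""
    return records.mean(axis=0)

def features_per_record(records: np.ndarray) -> np.ndarray:
    """Number of codes present in each record (row sums)."""
    return records.sum(axis=1)

def feature_count_density(counts: np.ndarray, bins: int = 175, value_range=(0, 175)):
    """Normalized histogram of per-record feature counts, using the 175-bin (0, 175) setup."""
    density, edges = np.histogram(counts, bins=bins, range=value_range, density=True)
    return density, edges

# Usage with random toy data standing in for real and synthetic records.
rng = np.random.default_rng(0)
real = (rng.random((1000, 50)) < 0.05).astype(np.int8)
synth = (rng.random((1000, 50)) < 0.05).astype(np.int8)
prev_real, prev_synth = code_prevalence(real), code_prevalence(synth)
# Points (prev_real[j], prev_synth[j]) close to the diagonal indicate matching prevalence.
density_real, edges = feature_count_density(features_per_record(real))
```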
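
The confidence bands in the augmentation experiment come from bootstrapping test AUROC. The caption does not say whether the 80% subsamples are drawn with or without replacement or how the interval is formed, so this sketch assumes sampling with replacement and a percentile interval; the function names are hypothetical. The pairwise AUROC below needs memory proportional to positives times negatives, so for 200K test samples a rank-based implementation such as scikit-learn's `roc_auc_score` is the practical choice.

```python
import numpy as np

def auroc(y_true, scores) -> float:
    """Probability that a random positive outscores a random negative (ties count 1/2).
    Pairwise comparison, so only suitable for small arrays."""
    y = np.asarray(y_true).astype(bool)
    s = np.asarray(scores, dtype=float)
    pos, neg = s[y], s[~y]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))

def bootstrap_auroc_ci(y_true, scores, n_boot=50, frac=0.8, alpha=0.05, seed=0):
    """Percentile CI over n_boot resamples, each of size frac * n drawn with replacement."""
    y = np.asarray(y_true)
    s = np.asarray(scores, dtype=float)
    rng = np.random.default_rng(seed)
    m = int(frac * len(y))
    stats = []
    for _ in range(n_boot):
        idx = rng.choice(len(y), size=m, replace=True)
        if y[idx].min() == y[idx].max():  # AUROC is undefined for a single-class resample
            continue
        stats.append(auroc(y[idx], s[idx]))
    low, high = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(low), float(high)

# Usage on toy labels and noisy scores.
rng = np.random.default_rng(1)
y = rng.integers(0, 2, size=500)
scores = y + rng.normal(0.0, 1.0, size=500)
low, high = bootstrap_auroc_ci(y, scores)
```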
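
The denoising-model caption combines several standard components. The minimal NumPy sketch below assumes single-head attention with learned length-axis projections E and F as described in the Linformer paper, plus a DDPM-style sinusoidal time embedding followed by a Softplus MLP; all weights, shapes, and function names are hypothetical illustrations rather than the authors' implementation, and axial positional embedding and the rest of the transformer block are omitted.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def linformer_attention(X, Wq, Wk, Wv, E, F):
    """Single-head Linformer-style attention.
    X: (B, N, d_model); Wq, Wk, Wv: (d_model, d); E, F: (k, N) projections along the length axis.
    Scores are (B, N, k) instead of (B, N, N), so cost is linear in N for fixed k."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                                 # (B, N, d)
    K_proj, V_proj = E @ K, F @ V                                    # (B, k, d)
    scores = Q @ np.swapaxes(K_proj, -1, -2) / np.sqrt(Q.shape[-1])  # (B, N, k)
    return softmax(scores, axis=-1) @ V_proj                         # (B, N, d)

def sinusoidal_time_embedding(t, dim: int, max_period: float = 10000.0) -> np.ndarray:
    """Map time steps t (shape (B,)) to (B, dim) sin/cos features; dim assumed even."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=float)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

def time_mlp(emb, W1, b1, W2, b2):
    """Two-layer MLP with Softplus(x) = log(1 + e^x) applied after the first layer."""
    return np.logaddexp(0.0, emb @ W1 + b1) @ W2 + b2

# Usage: k = 128 mirrors the projection dimension stated in the caption; other sizes are toy values.
rng = np.random.default_rng(0)
B, N, d_model, d, k = 2, 512, 32, 16, 128
X = rng.normal(size=(B, N, d_model))
Wq, Wk, Wv = (rng.normal(size=(d_model, d)) / np.sqrt(d_model) for _ in range(3))
E, F = (rng.normal(size=(k, N)) / np.sqrt(N) for _ in range(2))
out = linformer_attention(X, Wq, Wk, Wv, E, F)
h_t = time_mlp(sinusoidal_time_embedding(np.array([0, 50]), 64),
               rng.normal(size=(64, 64)), np.zeros(64), rng.normal(size=(64, d)), np.zeros(d))
```

For N = 2683 codes and k = 128, the score matrix has roughly 21 times fewer entries than the N × N matrix of standard attention, which is where the linear complexity comes from.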