Transformer-based Time-to-Event Prediction for Chronic Kidney Disease Deterioration

Moshe Zisser, Dvir Aran

TL;DR

This work tackles time-to-event prediction in chronic kidney disease using large-scale health claims data. It introduces STRAFE, a transformer-based survival-analysis model that embeds SNOMED concepts, applies self-attention to capture temporal and cross-visit context, and uses a convolutional head to generate monthly survival predictions $q(t|X)$ over a horizon of up to $T_{max}$ months. STRAFE outperforms traditional baselines (e.g., RSF, DeepHit) for exact time-to-event prediction and, when trained on censored data, improves fixed-time risk prediction as measured by AUC-ROC, while enabling per-patient explainability via attention visualizations. The method shows potential for targeted, early interventions in large healthcare datasets, offering both improved risk stratification and interpretable patient narratives to guide care management.
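As a rough illustration of how monthly outputs could be turned into a survival curve and a fixed-time risk, the sketch below assumes that $q(t|X)$ is the conditional probability of remaining event-free through month $t$ given survival up to month $t-1$. That interpretation is an assumption made here for illustration, since this summary does not define $q(t|X)$ precisely. The function names are hypothetical and are not taken from the STRAFE code.

```python
from itertools import accumulate
from operator import mul


def survival_curve(q):
    """Cumulative survival S(t) = q(1) * q(2) * ... * q(t) for t = 1..T_max.

    Assumes q[t-1] is the conditional probability of surviving month t
    given survival through month t-1 (an assumption of this sketch).
    """
    return list(accumulate(q, mul))


def fixed_time_risk(q, month):
    """Probability of an event by the end of `month` (1-indexed): 1 - S(month)."""
    if not 1 <= month <= len(q):
        raise ValueError("month must lie within the prediction horizon")
    return 1.0 - survival_curve(q)[month - 1]


def expected_months_survived(q):
    """Expected number of months survived within the horizon: sum of S(t)."""
    return sum(survival_curve(q))


# Toy three-month horizon with made-up probabilities.
q_example = [0.9, 0.8, 0.5]
curve = survival_curve(q_example)
risk_at_2 = fixed_time_risk(q_example, 2)
mean_months = expected_months_survived(q_example)
```

The same curve serves both tasks discussed in the summary: its shape gives the time-to-event prediction, and its value at a chosen month gives a fixed-time risk score.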

Abstract

Deep-learning techniques, particularly the transformer model, have shown great potential for improving predictions from longitudinal health records. While previous methods have mainly focused on fixed-time risk prediction, time-to-event prediction (also known as survival analysis) is often more appropriate for clinical scenarios. Here, we present STRAFE, a novel, generalizable transformer-based survival analysis architecture for electronic health records. The performance of STRAFE was evaluated using a real-world claims dataset of over 130,000 individuals with stage 3 chronic kidney disease (CKD), where it outperformed other time-to-event prediction algorithms in predicting the exact time of deterioration to stage 5. Additionally, STRAFE outperformed binary outcome algorithms in predicting fixed-time risk, possibly due to its ability to train on censored data. We show that STRAFE predictions can increase the positive predictive value for high-risk patients 3-fold, demonstrating possible usage to improve targeting for intervention programs. Finally, we suggest a novel visualization approach to explain predictions on a per-patient basis. In conclusion, STRAFE is a time-to-event prediction algorithm that has the potential to enhance risk predictions in large claims datasets.
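To make the point about censored data concrete, the following minimal sketch shows why a binary fixed-time classifier has to discard some patients while a time-to-event model can keep them. It assumes each patient is summarized by a follow-up duration in months and an event flag; the function names and the toy cohort are hypothetical and do not come from the paper.

```python
def fixed_time_label(duration, event, horizon):
    """Binary label for "event within `horizon` months", or None if unknown.

    duration: months of follow-up until the event or the last claim.
    event:    True if deterioration was observed, False if censored.
    """
    if event and duration <= horizon:
        return 1  # event observed within the horizon
    if duration >= horizon:
        return 0  # known to be event-free through the horizon
    return None  # censored before the horizon: outcome unknown


def usable_for_fixed_time(cohort, horizon):
    """Keep only patients whose fixed-time label is known."""
    labelled = []
    for duration, event in cohort:
        label = fixed_time_label(duration, event, horizon)
        if label is not None:
            labelled.append((duration, event, label))
    return labelled


# Toy cohort of (duration_months, event_observed) pairs, made up for illustration.
toy_cohort = [(10, True), (30, True), (8, False), (40, False)]
binary_training_set = usable_for_fixed_time(toy_cohort, horizon=24)
```

In this toy cohort the patient censored at 8 months has no 24-month label and is dropped from the binary task, whereas a survival model can still learn from the fact that this patient was event-free for 8 months.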

Paper Structure (21 sections, 9 equations, 5 figures, 3 tables)

Figures (5)

  • Figure 1: STRAFE architecture. Each patient visit consists of a set of concepts. The concepts of each visit are transformed into a concept embedding vector using a pre-trained embedding. These vectors, together with a temporal embedding vector, are fed into a self-attention layer, which outputs contextualized visits. A convolution layer then transforms these into per-month representations. A second self-attention mechanism contextualizes the months, and an MLP layer extracts survival probabilities per month.
  • Figure 2: Cohort selection procedure. A de-identified claims database from a large US health insurer was used. A cohort of stage 3 CKD patients was identified and divided into train and test sets using an 80/20 split. Time-to-event algorithms train on the full cohort, while fixed-time algorithms can only use non-censored patients.
  • Figure 3: Time-to-event prediction. a. Study design: the observation period includes all visits prior to and including the first stage 3 CKD indication. Patients were then followed up until a stage 5 CKD indication or the last claim. b. Example of predicted survival curves from three algorithms for one patient who had an event after 34 months. A survival curve that shows a sharp decline around the time of the actual event is considered informative. c. Boxplots of time-to-event duration in each decile. Test set patients were grouped into deciles according to the predicted mean survival time.
  • Figure 4: Fixed-time prediction. a. An illustration of transforming a time-to-event prediction into a fixed-time prediction: the probability at the fixed time point is used as the predicted probability. b. Boxplots of bootstrapped AUC-ROC values for three algorithms on the test set in the 24-month task. P-value (STRAFE vs. SARD) = 2e-8. c. Test set patients were divided into deciles according to the probabilities predicted by STRAFE at 12 months. The y-axis shows the percentage of patients who deteriorated to stage 5 CKD. The red line is the percentage in the full test cohort.
  • Figure 5: Per-patient explainability. a. An example heatmap of the interaction values of the attention matrix for a patient who had 12 visits. Visits 1 and 9 have the highest values. b. A graph-based approach for visualizing the connectivity between the visits. Dots are colored according to the ICD-10 code chapter of the primary diagnosis of the visit. c. The original predicted survival curve for the patient compared to the predicted survival curve following the removal of visits 1 and 9.
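The decile analysis described for Figure 4c, and the fold improvement in positive predictive value mentioned in the abstract, can be sketched as a simple lift calculation: rank patients by predicted risk, take the top decile, and compare its event rate with the rate in the whole cohort. This is a minimal sketch under that reading; the function name and the toy data are hypothetical, and the paper's exact procedure may differ.

```python
def top_decile_lift(risks, outcomes):
    """Compare the event rate in the highest-risk decile with the cohort rate.

    risks:    predicted event probabilities, one per patient.
    outcomes: 1 if the patient deteriorated within the horizon, else 0.
    Returns (top-decile event rate, cohort event rate, ratio of the two).
    """
    n = len(risks)
    if n == 0 or n != len(outcomes):
        raise ValueError("risks and outcomes must be non-empty and equal in length")
    base_rate = sum(outcomes) / n
    if base_rate == 0:
        raise ValueError("lift is undefined when the cohort has no events")
    # Indices of patients sorted from highest to lowest predicted risk.
    order = sorted(range(n), key=lambda i: risks[i], reverse=True)
    k = max(1, n // 10)  # size of the top decile
    top_rate = sum(outcomes[i] for i in order[:k]) / k
    return top_rate, base_rate, top_rate / base_rate


# Toy data: 20 patients with made-up risks and outcomes.
toy_risks = [i / 20 for i in range(20)]
toy_outcomes = [0] * 17 + [1, 1, 1]
top_rate, base_rate, lift = top_decile_lift(toy_risks, toy_outcomes)
```

The cohort event rate plays the role of the red reference line in Figure 4c, and the ratio expresses how much more often high-risk patients deteriorate than an average patient in the cohort.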