
Temporal Supervised Contrastive Learning for Modeling Patient Risk Progression

Shahriar Noroozizadeh, Jeremy C. Weiss, George H. Chen

TL;DR

This work introduces Temporal-SCL, an embedding-centric framework for modeling patient risk progression in variable-length tabular time series. It learns a hyperspherical embedding for each time step and enforces temporal smoothness through a dedicated temporal network, so the embedding space captures both outcome-consistent similarity and the proximity of adjacent time steps. A nearest-neighbor pairing in raw feature space preserves raw-feature heterogeneity without relying on clustering during training. The model is pre-trained with a snapshot-level supervised contrastive loss, then jointly trained with a temporal regularization term, and finally deployed with a predictor head that outputs class probabilities at every time step. On synthetic data and on real-world clinical datasets (MIMIC-III and ADNI), Temporal-SCL outperforms state-of-the-art baselines, and ablation experiments show that the nearest-neighbor pairing is pivotal for both structure recovery and predictive performance. The work also provides heatmap visualizations that relate embedding clusters to raw features and outcomes, offering clinically interpretable insights and a path toward trajectory-aware decision support.
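The two training signals named above can be sketched in a few lines. `supcon_loss` follows the standard supervised contrastive loss (Khosla et al.) applied to L2-normalized snapshot embeddings; `temporal_smoothness` is a simplified, hypothetical stand-in (a squared-distance penalty between consecutive embeddings), not the paper's dedicated temporal network. The temperature value and all function names are assumptions for illustration.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Project a (nonzero) vector onto the unit hypersphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def supcon_loss(embeddings: Sequence[Sequence[float]],
                labels: Sequence[int],
                temperature: float = 0.1) -> float:
    """Supervised contrastive loss over a batch of snapshot embeddings.

    Each anchor is pulled toward every other snapshot with the same label and
    pushed away from the rest of the batch. Anchors with no positive are skipped.
    """
    z = [l2_normalize(e) for e in embeddings]
    n = len(z)
    total, n_anchors = 0.0, 0
    for i in range(n):
        positives = [p for p in range(n) if p != i and labels[p] == labels[i]]
        if not positives:
            continue
        # Log-sum-exp over all other samples (the softmax denominator), computed stably.
        logits = [dot(z[i], z[a]) / temperature for a in range(n) if a != i]
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(s - m) for s in logits))
        total += sum(log_denom - dot(z[i], z[p]) / temperature
                     for p in positives) / len(positives)
        n_anchors += 1
    return total / max(n_anchors, 1)


def temporal_smoothness(series: Sequence[Sequence[float]]) -> float:
    """Mean squared distance between consecutive normalized embeddings of one series.

    Hypothetical stand-in for the temporal network: it is small when adjacent
    time steps sit close together on the hypersphere.
    """
    z = [l2_normalize(e) for e in series]
    if len(z) < 2:
        return 0.0
    gaps = [sum((a - b) ** 2 for a, b in zip(z[t], z[t + 1]))
            for t in range(len(z) - 1)]
    return sum(gaps) / len(gaps)
```

Under these assumptions, pre-training would minimize `supcon_loss` alone, and the joint stage could minimize something like `supcon_loss(batch, labels) + lam * temporal_smoothness(series)`, where the weight `lam` is hypothetical.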

Abstract

We consider the problem of predicting how the likelihood of an outcome of interest for a patient changes over time as we observe more of the patient data. To solve this problem, we propose a supervised contrastive learning framework that learns an embedding representation for each time step of a patient time series. Our framework learns the embedding space to have the following properties: (1) nearby points in the embedding space have similar predicted class probabilities, (2) adjacent time steps of the same time series map to nearby points in the embedding space, and (3) time steps with very different raw feature vectors map to far-apart regions of the embedding space. To achieve property (3), we employ a nearest-neighbor pairing mechanism in the raw feature space. This mechanism also serves as an alternative to data augmentation, a key ingredient of contrastive learning for which, to our knowledge, no standard procedure is adequately realistic for clinical tabular data. We demonstrate that our approach outperforms state-of-the-art baselines in predicting mortality of septic patients (MIMIC-III dataset) and tracking progression of cognitive impairment (ADNI dataset). Our method also consistently recovers the correct embedding structure of the synthetic dataset across experiments, a feat not achieved by baselines. Our ablation experiments show the pivotal role of our nearest-neighbor pairing.
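A minimal sketch of nearest-neighbor pairing, assuming Euclidean distance on (already standardized) raw feature vectors and a search within the current batch. The abstract does not specify the distance, the search scope, whether neighbors must share a label, or exactly how a pair enters the loss, so the function below is hypothetical. The idea it illustrates is that each snapshot's nearest neighbor can play the role an augmented "view" plays in standard contrastive learning.

```python
import math
from typing import List, Sequence


def nearest_neighbor_pairs(features: Sequence[Sequence[float]]) -> List[int]:
    """For each snapshot, return the index of its nearest *other* snapshot.

    Distance is Euclidean on raw (assumed pre-standardized) feature vectors;
    ties go to the lowest index. This is an O(n^2) scan, fine for one batch.
    """
    pairs = []
    for i, x in enumerate(features):
        best_j, best_d = -1, math.inf
        for j, y in enumerate(features):
            if j == i:
                continue
            d = math.dist(x, y)
            if d < best_d:
                best_j, best_d = j, d
        pairs.append(best_j)
    return pairs
```

For example, with two tight groups of snapshots, each point is paired with its partner in the same group, so `(i, pairs[i])` could then be fed to a contrastive loss as an extra positive pair.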

Paper Structure

This paper contains 31 sections, 6 equations, 9 figures, and 5 tables.

Figures (9)

  • Figure 3.1: Overview of Temporal-SCL
  • Figure 3.2: Heatmap showing how features (rows) vary across clusters (columns) for the sepsis cohort of the MIMIC dataset. Each heatmap intensity can be read as the conditional probability of observing a feature value (row) given membership in a cluster (column); these probabilities are estimated from test-set snapshots. Columns are ordered from left to right by increasing fraction of test-set snapshots that come from a time series whose final outcome is death.
  • Figure 4.1: Synthetic dataset. Panel (a) shows the only 4 possible time series trajectories (each true embedding state has a unique color-shape combination, and there are 10 such states); every time series has 3 time steps and belongs to one of two classes, red or blue. Panels (b)-(e) show the learned embedding spaces of four methods; only Temporal-SCL correctly recovers the 10 ground-truth states. A version of this figure with the embeddings of all evaluated methods is given in Fig. fig:toyDataset_result_complete.
  • Figure A.1: Inference-time setup for the MIMIC experiments with Temporal-SCL. First, for each patient we extract time series data from a time window around sepsis onset within the patient's complete ICU timeline. This data is binned into time steps at 4-hour intervals. The features at each time step are then mapped into the embedding space learned by our encoder. Finally, the resulting embeddings are passed through our predictor network to predict the patient's ICU mortality at every time step.
  • Figure A.2: Inference-time setup for the ADNI experiment with Temporal-SCL. First, for each patient we extract time series data from the complete timeline of available data spanning all 6-month follow-up visits. The time series keeps the same 6-month interval time steps as the raw data, and each time step has its own true class label representing one of three possible brain function states. The features at each time step are then mapped into the embedding space learned by our encoder. Finally, the resulting embeddings are passed through our predictor network to predict the patient's brain function class at every time step.
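The conditional probabilities described in the Figure 3.2 caption can be estimated by simple counting. The sketch below assumes each test-set snapshot already has a cluster id (the caption does not say how clusters are obtained) and discretized feature values such as `{"HR": "high"}`; the function name and input layout are hypothetical.

```python
from collections import Counter, defaultdict
from typing import Dict, Hashable, Sequence


def cluster_heatmap(clusters: Sequence[Hashable],
                    snapshots: Sequence[Dict[str, str]],
                    died: Sequence[bool]):
    """Estimate P(feature value | cluster) from test-set snapshots.

    clusters[i]  -- cluster id assigned to snapshot i (assumed given)
    snapshots[i] -- discretized features of snapshot i, e.g. {"HR": "high"}
    died[i]      -- whether snapshot i's time series ends in death
    Returns (column_order, rows): columns sorted by increasing death fraction,
    rows[(feature, value)] = conditional probabilities, one per column.
    """
    size = Counter(clusters)
    deaths = Counter()
    counts = defaultdict(Counter)  # counts[(feature, value)][cluster]
    for c, snap, d in zip(clusters, snapshots, died):
        deaths[c] += int(d)
        for feature, value in snap.items():
            counts[(feature, value)][c] += 1
    order = sorted(size, key=lambda c: deaths[c] / size[c])
    rows = {row: [counts[row][c] / size[c] for c in order] for row in sorted(counts)}
    return order, rows
```

Each entry is the share of a cluster's snapshots showing that feature value, which matches the "conditional probability given the cluster" reading in the caption.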
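The inference setups in Figures A.1 and A.2 share one pipeline: raw measurements are grouped into fixed-interval time steps (4 hours for MIMIC; ADNI visits already sit on a 6-month grid), each step is embedded, and a predictor head outputs class probabilities per step. The following is a simplified, hypothetical version. Averaging within a bin, forward-fill imputation, and a linear encoder standing in for the learned network are all assumptions, and every name is illustrative.

```python
import math
from collections import defaultdict
from typing import List, Sequence, Tuple

Measurement = Tuple[float, str, float]  # (time relative to anchor, feature name, value)


def bin_time_series(measurements: Sequence[Measurement], window_start: float,
                    window_end: float, step_width: float,
                    feature_names: Sequence[str]) -> List[List[float]]:
    """Average raw measurements into fixed-width steps inside [window_start, window_end).

    Steps without a measurement carry the last observed value forward (0.0 before
    the first observation): a hypothetical imputation choice for illustration.
    """
    n_steps = math.ceil((window_end - window_start) / step_width)
    sums = [defaultdict(float) for _ in range(n_steps)]
    counts = [defaultdict(int) for _ in range(n_steps)]
    for t, name, value in measurements:
        if window_start <= t < window_end:
            k = int((t - window_start) // step_width)
            sums[k][name] += value
            counts[k][name] += 1
    steps, last = [], {f: 0.0 for f in feature_names}
    for k in range(n_steps):
        for f in feature_names:
            if counts[k][f]:
                last[f] = sums[k][f] / counts[k][f]
        steps.append([last[f] for f in feature_names])
    return steps


def matvec(W: Sequence[Sequence[float]], x: Sequence[float]) -> List[float]:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_per_time_step(steps, encoder_W, predictor_W, predictor_b):
    """Embed each step on the unit hypersphere, then apply a softmax predictor head."""
    probs = []
    for x in steps:
        z = matvec(encoder_W, x)
        norm = math.sqrt(sum(v * v for v in z)) or 1.0  # guard against a zero vector
        z = [v / norm for v in z]
        logits = [s + b for s, b in zip(matvec(predictor_W, z), predictor_b)]
        probs.append(softmax(logits))
    return probs
```

The predictor is class-count agnostic: two output rows would correspond to ICU mortality in the MIMIC setup, and three to the brain function states in the ADNI setup.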