Temporal Supervised Contrastive Learning for Modeling Patient Risk Progression
Shahriar Noroozizadeh, Jeremy C. Weiss, George H. Chen
TL;DR
This work introduces Temporal-SCL, an embedding-centric framework for modeling patient risk progression in variable-length tabular time series. It learns a hyperspherical embedding for each time step and enforces temporal smoothness via a dedicated temporal network, so that the embedding space captures both outcome-consistent similarity and the proximity of adjacent time steps. A nearest-neighbor pairing in raw feature space preserves raw feature heterogeneity without relying on clustering during training. The model is first pre-trained with a snapshot-level supervised contrastive loss, then jointly trained with a temporal regularization term, and finally deployed with a predictor head that outputs class probabilities at each time step. Across synthetic data and real-world clinical datasets (MIMIC-III and ADNI), Temporal-SCL outperforms state-of-the-art baselines, and ablation experiments show that the nearest-neighbor pairing is pivotal for both structure recovery and predictive performance. The work also provides heatmap visualizations that relate embedding clusters to raw features and outcomes, offering clinically interpretable insights and a path toward trajectory-aware decision support.
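The training objective summarized above can be illustrated with a minimal pure-Python sketch. It assumes a SupCon-style loss on L2-normalized (unit-hypersphere) embeddings with a temperature of 0.1, and it replaces the paper's dedicated temporal network with a simple squared-distance penalty between consecutive embeddings, weighted by a hypothetical `lam`. The function names (`supcon_loss`, `temporal_smoothness`, `joint_objective`) and all hyperparameter values are illustrative assumptions, not the authors' implementation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def l2_normalize(v: Vector) -> List[float]:
    """Project a vector onto the unit hypersphere."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def supcon_loss(embeddings: Sequence[Vector], labels: Sequence[int],
                temperature: float = 0.1) -> float:
    """SupCon-style loss over a batch of snapshot embeddings (assumed form).

    For each anchor i, the positives are the other snapshots sharing its label.
    The loss is the mean negative log-softmax of each positive, where the softmax
    runs over all other snapshots using cosine similarity / temperature.
    """
    z = [l2_normalize(e) for e in embeddings]
    total, n_anchors = 0.0, 0
    for i, zi in enumerate(z):
        others = [j for j in range(len(z)) if j != i]
        positives = [j for j in others if labels[j] == labels[i]]
        if not positives:
            continue  # an anchor with no positive contributes nothing
        logits = {j: sum(a * b for a, b in zip(zi, z[j])) / temperature for j in others}
        m = max(logits.values())  # subtract the max for a numerically stable log-sum-exp
        log_denom = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total += -sum(logits[p] - log_denom for p in positives) / len(positives)
        n_anchors += 1
    return total / n_anchors if n_anchors else 0.0


def temporal_smoothness(trajectory: Sequence[Vector]) -> float:
    """Mean squared distance between consecutive unit-norm embeddings of one patient.

    Simplified stand-in (an assumption) for the paper's dedicated temporal network.
    """
    z = [l2_normalize(e) for e in trajectory]
    if len(z) < 2:
        return 0.0
    gaps = [sum((a - b) ** 2 for a, b in zip(z[t], z[t + 1])) for t in range(len(z) - 1)]
    return sum(gaps) / len(gaps)


def joint_objective(embeddings: Sequence[Vector], labels: Sequence[int],
                    trajectories: Sequence[Sequence[Vector]],
                    lam: float = 0.5, temperature: float = 0.1) -> float:
    """Hypothetical joint loss: SupCon term + lam * mean temporal smoothness."""
    temporal = sum(temporal_smoothness(tr) for tr in trajectories) / len(trajectories)
    return supcon_loss(embeddings, labels, temperature) + lam * temporal
```

For example, with embeddings `[[1, 0], [0.9, 0.1], [0, 1], [0.1, 0.9]]`, the labeling `[0, 0, 1, 1]` yields a much smaller `supcon_loss` than `[0, 1, 0, 1]`, because same-label snapshots already sit close together on the sphere.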
Abstract
We consider the problem of predicting how the likelihood of an outcome of interest for a patient changes over time as more of the patient's data are observed. To solve this problem, we propose a supervised contrastive learning framework that learns an embedding representation for each time step of a patient time series. Our framework trains the embedding space to have the following properties: (1) nearby points in the embedding space have similar predicted class probabilities, (2) adjacent time steps of the same time series map to nearby points in the embedding space, and (3) time steps with very different raw feature vectors map to far-apart regions of the embedding space. To achieve property (3), we employ a nearest neighbor pairing mechanism in the raw feature space. This mechanism also serves as an alternative to data augmentation, a key ingredient of contrastive learning for which, to our knowledge, there is no standard procedure that is adequately realistic for clinical tabular data. We demonstrate that our approach outperforms state-of-the-art baselines in predicting mortality of septic patients (MIMIC-III dataset) and in tracking the progression of cognitive impairment (ADNI dataset). Across experiments, our method also consistently recovers the correct embedding structure of a synthetic dataset, a feat not achieved by the baselines. Our ablation experiments show the pivotal role of our nearest neighbor pairing.
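A minimal sketch of nearest neighbor pairing in raw feature space may help make property (3) concrete. It assumes Euclidean distance on already-standardized raw features and, optionally, restricts candidates to snapshots with the same label; both choices, and the function name `nearest_neighbor_pairs`, are illustrative assumptions rather than the authors' exact procedure.

```python
import math
from typing import List, Optional, Sequence


def nearest_neighbor_pairs(
    features: Sequence[Sequence[float]],
    labels: Optional[Sequence[int]] = None,
) -> List[int]:
    """For each snapshot i, return the index of its nearest other snapshot.

    Assumptions: Euclidean distance on (pre-standardized) raw features; if
    labels are given, candidates are restricted to snapshots with the same
    label. Returns -1 for a snapshot that has no valid candidate.
    """
    pairs = []
    for i, x in enumerate(features):
        best_j, best_d = -1, math.inf
        for j, y in enumerate(features):
            if j == i:
                continue  # a snapshot is never paired with itself
            if labels is not None and labels[j] != labels[i]:
                continue  # optional same-label restriction (assumption)
            d = math.dist(x, y)
            if d < best_d:
                best_j, best_d = j, d
        pairs.append(best_j)
    return pairs


points = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.2, 5.1]]
print(nearest_neighbor_pairs(points))                # pairs within each raw-feature cluster
print(nearest_neighbor_pairs(points, [0, 1, 0, 1]))  # pairs restricted to the same label
```

Each pair `(i, pairs[i])` can then play the role that an (anchor, augmented view) pair plays in standard contrastive learning, which is how the pairing substitutes for data augmentation. The brute-force search above is quadratic in the number of snapshots; for larger datasets a spatial index such as scikit-learn's `NearestNeighbors` would be the usual choice.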
