Neural Topic Models with Survival Supervision: Jointly Predicting Time-to-Event Outcomes and Learning How Clinical Features Relate
George H. Chen, Linhong Li, Ren Zuo, Amanda Coston, Jeremy C. Weiss
TL;DR
This work tackles time-to-event prediction in healthcare with high-dimensional, heterogeneous data by introducing neural survival-supervised topic models, which jointly learn a topic representation of clinical events and a survival model. The approach combines neural topic models (LDA or SAGE) with classical survival models (Cox or AFT), uses the learned topic weights as inputs to the survival branch, and introduces a background topic for interpretability. It instantiates four model variants (LDA-cox, LDA-aft, SAGE-cox, SAGE-aft) and evaluates them on seven clinical datasets, showing competitive predictive accuracy while yielding interpretable clinical topics that are visualized with heatmaps and post-hoc filtering. The paper also discusses practical considerations (visualization, scalability), possible extensions to other data types, limitations, and directions for theoretical analysis and improved interpretability.
Abstract
We present a neural network framework for learning a survival model to predict a time-to-event outcome while simultaneously learning a topic model that reveals feature relationships. In particular, we model each subject as a distribution over "topics", where a topic could, for instance, correspond to an age group, a disorder, or a disease. The presence of a topic in a subject means that specific clinical features are more likely to appear for the subject. Topics encode information about related features and are learned in a supervised manner to predict a time-to-event outcome. Our framework supports combining many different topic and survival models; training the resulting joint survival-topic model readily scales to large datasets using standard neural net optimizers with minibatch gradient descent. For example, one special case combines LDA with a Cox model, in which case a subject's distribution over topics serves as the input feature vector to the Cox model. We explain how to address practical implementation issues that arise when applying these neural survival-supervised topic models to clinical data, including how to visualize results to assist clinical interpretation. We evaluate the proposed framework on seven clinical datasets for predicting time until death as well as hospital ICU length of stay, and find that neural survival-supervised topic models achieve accuracy competitive with existing approaches while yielding interpretable clinical topics that explain feature relationships. Our code is available at: https://github.com/georgehc/survival-topics
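To make the LDA-with-Cox special case concrete, here is a minimal NumPy sketch of a joint objective evaluated on one minibatch. It is not the authors' implementation (their code is in the repository linked above), and several simplifications are assumptions. A hypothetical deterministic linear-softmax encoder stands in for the variational encoder of a neural LDA, so there is no KL term. Tied event times are not treated specially. The names `joint_loss`, `W_enc`, `topic_logits`, `cox_beta`, and `lam` are illustrative. Actual training would wrap such a loss in an autodiff framework and minimize it with a standard minibatch optimizer.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cox_neg_log_partial_likelihood(risk, times, events):
    """Cox negative log partial likelihood, averaged over observed events.

    risk[i]:   log hazard ratio for subject i
    times[i]:  observed time (event or censoring)
    events[i]: 1 if the event was observed, 0 if censored
    Assumption: ties in times are not handled specially (no Breslow/Efron).
    """
    order = np.argsort(-times)              # sort subjects by decreasing time
    risk, events = risk[order], events[order]
    # Running log-sum-exp gives log sum_{j: t_j >= t_i} exp(risk_j).
    log_risk_set = np.logaddexp.accumulate(risk)
    n_events = max(events.sum(), 1)
    return -np.sum((risk - log_risk_set) * events) / n_events

def joint_loss(X, times, events, W_enc, topic_logits, cox_beta, lam=1.0):
    """Hypothetical LDA-style reconstruction loss plus a Cox loss on topics.

    X:            (n, V) bag-of-words counts of clinical features
    W_enc:        (V, K) hypothetical encoder weights
    topic_logits: (K, V) unnormalized topic-word parameters
    cox_beta:     (K,)   Cox coefficients on topic proportions
    """
    theta = softmax(X @ W_enc, axis=1)          # (n, K) subject-topic proportions
    beta = softmax(topic_logits, axis=1)        # (K, V) topic-word distributions
    word_probs = theta @ beta                   # (n, V) each row sums to one
    recon_nll = -np.sum(X * np.log(word_probs + 1e-12)) / X.shape[0]
    risk = theta @ cox_beta                     # topic proportions feed the Cox model
    surv_nll = cox_neg_log_partial_likelihood(risk, times, events)
    return recon_nll + lam * surv_nll
```

Because each row of `theta` sums to one, adding the same constant to every entry of `cox_beta` shifts all risk scores equally and leaves the Cox partial likelihood unchanged, so in this sketch the topic coefficients are identified only up to such a shift.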
