Gradient-based Explanations for Deep Learning Survival Models
Sophie Hanna Langbein, Niklas Koenen, Marvin N. Wright
TL;DR
This work addresses the opacity of deep learning models for time-to-event prediction by formalizing a gradient-based explainability framework for survival neural networks that yields time-dependent, post-hoc explanations. It introduces GradSHAP(t) and adapts existing gradient-based attribution methods (Grad(t), SG(t), G×I(t), IntGrad(t)) to the survival setting, analyzing what their theoretical assumptions imply for survival data. The authors show that these methods recover both time-invariant and time-varying feature effects, that GradSHAP(t) offers a better speed-accuracy trade-off than SurvSHAP(t) and SurvLIME, and that the methods extend to multi-modal medical data including images, with visualizations of the temporal dynamics. The approach aims to improve transparency in clinical decision-making for personalized medicine, and an R/torch implementation is provided to facilitate adoption and further research. Future work will extend the framework to competing risks, multi-state models, and feature interactions.
Abstract
Deep learning survival models often outperform classical methods in time-to-event predictions, particularly in personalized medicine, but their "black box" nature hinders broader adoption. We propose a framework for gradient-based explanation methods tailored to survival neural networks, extending their use beyond regression and classification. We analyze the implications of their theoretical assumptions for time-dependent explanations in the survival setting and propose effective visualizations incorporating the temporal dimension. Experiments on synthetic data show that gradient-based methods capture the magnitude and direction of local and global feature effects, including time dependencies. We introduce GradSHAP(t), a gradient-based counterpart to SurvSHAP(t), which outperforms SurvSHAP(t) and SurvLIME in a computational speed vs. accuracy trade-off. Finally, we apply these methods to medical data with multi-modal inputs, revealing relevant tabular features and visual patterns, as well as their temporal dynamics.
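As a rough illustration of what a time-dependent gradient attribution looks like, the following Python sketch computes Grad(t), G×I(t), and IntGrad(t) for a toy survival function at several time points. The model (a proportional-hazards curve with a Weibull baseline), its weights, the all-zero baseline, and every function name are assumptions made for illustration; none of them come from the paper or its R/torch implementation. The attribution formulas are the standard definitions of each method, evaluated separately at each time point, and the gradient is written in closed form where a real network would use automatic differentiation.

```python
import math

# Hypothetical toy model (an assumption, not the paper's network):
# S(t|x) = exp(-(t/SCALE)^SHAPE * exp(w . x)), a proportional-hazards curve.
WEIGHTS = [0.8, -0.5, 0.0]
SHAPE, SCALE = 1.5, 5.0


def cumulative_hazard(t, x):
    """Cumulative hazard H(t|x) of the toy model."""
    linear = sum(w * xi for w, xi in zip(WEIGHTS, x))
    return (t / SCALE) ** SHAPE * math.exp(linear)


def survival(t, x):
    """Predicted survival probability S(t|x) = exp(-H(t|x))."""
    return math.exp(-cumulative_hazard(t, x))


def grad_t(t, x):
    """Grad(t): dS(t|x)/dx_j for each feature, in closed form for this model."""
    s, h = survival(t, x), cumulative_hazard(t, x)
    return [-s * h * w for w in WEIGHTS]


def grad_x_input_t(t, x):
    """G x I(t): gradient multiplied elementwise by the input."""
    return [g * xi for g, xi in zip(grad_t(t, x), x)]


def int_grad_t(t, x, baseline, steps=200):
    """IntGrad(t): midpoint-rule approximation of integrated gradients."""
    avg = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        # Point on the straight path from the baseline to x.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for j, g in enumerate(grad_t(t, point)):
            avg[j] += g / steps
    return [(xi - b) * a for xi, b, a in zip(x, baseline, avg)]


x = [1.0, 2.0, -1.0]          # instance to explain (made-up values)
baseline = [0.0, 0.0, 0.0]    # reference input (an assumption)

for t in [1.0, 3.0, 6.0]:
    ig = int_grad_t(t, x, baseline)
    gap = survival(t, x) - survival(t, baseline)
    rounded = [round(v, 4) for v in ig]
    print(f"t={t}: IntGrad={rounded} sum={sum(ig):.4f} S(x)-S(base)={gap:.4f}")
```

Each function returns one attribution per feature at the requested time, so repeating the call over a grid of times gives an attribution curve per feature. For IntGrad(t), the attributions at each time sum, up to numerical error, to the difference in predicted survival between the instance and the baseline.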
