Gradient-based Explanations for Deep Learning Survival Models

Sophie Hanna Langbein, Niklas Koenen, Marvin N. Wright

TL;DR

This work addresses the opacity of deep learning models for time-to-event prediction by formalizing a gradient-based explainability framework for survival neural networks that yields time-dependent, post-hoc explanations. It introduces GradSHAP(t) alongside survival adaptations of existing gradient-based attribution methods (Grad(t), SG(t), G×I(t), IntGrad(t)) and analyzes what their theoretical assumptions imply for survival data. The authors show that these methods reveal both time-invariant and time-varying feature effects, that GradSHAP(t) offers a more favorable trade-off between computational speed and accuracy than SurvSHAP(t) and SurvLIME, and that the methods extend to multi-modal medical data including images, with visualizations designed to show temporal dynamics. The approach aims to improve transparency in clinical decision-making for personalized medicine, and an R/torch implementation is provided to ease adoption and further research. Planned future work covers competing risks, multi-state models, and feature interactions.

Abstract

Deep learning survival models often outperform classical methods in time-to-event predictions, particularly in personalized medicine, but their "black box" nature hinders broader adoption. We propose a framework for gradient-based explanation methods tailored to survival neural networks, extending their use beyond regression and classification. We analyze the implications of their theoretical assumptions for time-dependent explanations in the survival setting and propose effective visualizations incorporating the temporal dimension. Experiments on synthetic data show that gradient-based methods capture the magnitude and direction of local and global feature effects, including time dependencies. We introduce GradSHAP(t), a gradient-based counterpart to SurvSHAP(t), which outperforms SurvSHAP(t) and SurvLIME in a computational speed vs. accuracy trade-off. Finally, we apply these methods to medical data with multi-modal inputs, revealing relevant tabular features and visual patterns, as well as their temporal dynamics.
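
To make the idea of a time-dependent, gradient-based Shapley-style attribution more concrete, the following is a minimal Python sketch. It assumes that GradSHAP(t) follows the usual expected-gradients construction (gradients averaged over random baselines and random interpolation points), applied separately at each time point t. The toy exponential proportional hazards model, the coefficients `beta`, the rate `lam`, the `background` set, and all function names are hypothetical and chosen only for illustration. This is not the authors' R/torch implementation.

```python
import math
import random

def survival(t, x, beta, lam=0.1):
    """Toy model (assumption): S(t|x) = exp(-lam * t * exp(beta . x))."""
    lin = sum(b * v for b, v in zip(beta, x))
    return math.exp(-lam * t * math.exp(lin))

def survival_grad(t, x, beta, lam=0.1):
    """Analytic gradient dS(t|x)/dx_j = -S * H * beta_j for the toy model."""
    lin = sum(b * v for b, v in zip(beta, x))
    cum_hazard = lam * t * math.exp(lin)
    s = math.exp(-cum_hazard)
    return [-s * cum_hazard * b for b in beta]

def grad_shap_t(t, x, background, beta, n_samples=4000, seed=0):
    """Expected-gradients estimate at time t (Monte Carlo, so approximate)."""
    rng = random.Random(seed)
    phi = [0.0] * len(x)
    for _ in range(n_samples):
        ref = rng.choice(background)   # random baseline from the reference set
        alpha = rng.random()           # random position on the path ref -> x
        point = [r + alpha * (xj - r) for xj, r in zip(x, ref)]
        g = survival_grad(t, point, beta)
        for j in range(len(x)):
            phi[j] += (x[j] - ref[j]) * g[j]
    return [p / n_samples for p in phi]

# Hypothetical inputs: three features, the third has no effect (beta = 0).
beta = [0.8, -0.5, 0.0]
x = [1.0, 2.0, 3.0]
background = [[0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, -1.0], [-1.0, 1.0, 0.0]]

for t in (1.0, 5.0, 10.0):
    phi = grad_shap_t(t, x, background, beta)
    print(t, [round(p, 4) for p in phi])
```

Under these assumptions, the attributions at each t should sum approximately to the difference between the survival prediction for `x` and the mean survival prediction over the background set, and evaluating them on a grid of time points gives one relevance curve per feature.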

Paper Structure

This paper contains 37 sections, 9 equations, 39 figures, 4 tables.

Figures (39)

  • Figure 1: Overview of our workflow for generating time-dependent post-hoc explanations with gradient-based methods, using overall brain cancer survival prediction as an example. The approach utilizes a deep learning survival model with multi-modal input data, providing insights into the temporal dynamics of feature effects through tailored visualizations for different feature types.
  • Figure 2: Mathematical representations of gradient-based feature attribution methods adapted to survival NNs. Each block corresponds to a different underlying objective. For example, in the case of feature-wise relevances $R_{j}^t$ obtained from G$\times$I(t), the goal is to achieve a sum that equals $f(t|\bm{x})$, i.e., $\sum_{j=1}^p R_{j}^t = f(t|\bm{x})$.
  • Figure 3: Grad(t) (top) and G$\times$I(t) (bottom) relevance curves for selected observations using the DeepSurv model trained on the time-independent simulation dataset. The relevance values for each feature are represented by different colors (y-axis) and are plotted across time (x-axis), highlighting the temporal dynamics of feature contributions.
  • Figure 4: Grad(t) relevance curves for the selected observations and the DeepHit model trained on the time-dependent simulation dataset.
  • Figure 5: GradSHAP(t) relevance curves, with corresponding survival prediction curves, reference curve and their difference (top), contribution plots (middle) and force plots (bottom) for the selected observations and the CoxTime model trained on the time-dependent simulation dataset.
  • ...and 34 more figures
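
The attribution objectives named in the Figure 2 caption can be illustrated with a small Python sketch. It assumes a toy exponential proportional hazards model as the prediction function $f(t|\bm{x})$ (here the survival function) and uses its analytic gradient in place of automatic differentiation. The model, the values of `beta`, `lam`, `x`, and `baseline`, and all function names are hypothetical. For a nonlinear model like this one, the G×I(t) relevances generally only approximate the target sum, whereas the path-integrated variant recovers the difference to a baseline prediction up to numerical error.

```python
import math

def survival(t, x, beta, lam=0.1):
    """Toy model (assumption): S(t|x) = exp(-lam * t * exp(beta . x))."""
    lin = sum(b * v for b, v in zip(beta, x))
    return math.exp(-lam * t * math.exp(lin))

def grad_t(t, x, beta, lam=0.1):
    """Grad(t)-style output: dS(t|x)/dx_j, analytic for the toy model."""
    lin = sum(b * v for b, v in zip(beta, x))
    cum_hazard = lam * t * math.exp(lin)
    s = math.exp(-cum_hazard)
    return [-s * cum_hazard * b for b in beta]

def grad_x_input_t(t, x, beta, lam=0.1):
    """G x I(t)-style output: gradient multiplied element-wise by the input."""
    return [xj * g for xj, g in zip(x, grad_t(t, x, beta, lam))]

def int_grad_t(t, x, baseline, beta, lam=0.1, steps=200):
    """IntGrad(t)-style output: midpoint-rule path integral from baseline to x."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b0 + alpha * (xj - b0) for xj, b0 in zip(x, baseline)]
        g = grad_t(t, point, beta, lam)
        total = [acc + gj for acc, gj in zip(total, g)]
    return [(xj - b0) * acc / steps for xj, b0, acc in zip(x, baseline, total)]

# Hypothetical inputs: three features, the third has no effect (beta = 0).
beta = [0.8, -0.5, 0.0]
x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]

for t in (1.0, 5.0, 10.0):
    ig = int_grad_t(t, x, baseline, beta)
    gap = survival(t, x, beta) - survival(t, baseline, beta)
    # Last printed value is the completeness error of the path integral.
    print(t, [round(v, 4) for v in ig], round(sum(ig) - gap, 6))
```

In this toy setting, a positive coefficient raises the hazard, so the corresponding gradient of the survival function is negative at every t > 0, and a feature with a zero coefficient receives zero relevance from all three methods.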

Theorems & Definitions (3)

  • Definition 2.1: Survival Function
  • Definition 2.2: Hazard Function
  • Definition 2.3: Cumulative Hazard Function (CHF)
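
For orientation, the three quantities listed above are commonly written as follows. These are the standard textbook forms for a continuous event time $T$ and covariates $\bm{x}$; the paper's own Definitions 2.1 to 2.3 are not reproduced here and may use different notation.

```latex
% Survival function: probability of surviving beyond time t
S(t \mid \bm{x}) = P(T > t \mid \bm{x})

% Hazard function: instantaneous event rate at t, given survival up to t
h(t \mid \bm{x}) = \lim_{\Delta t \to 0}
  \frac{P(t \le T < t + \Delta t \mid T \ge t, \bm{x})}{\Delta t}

% Cumulative hazard function (CHF) and its link to the survival function
H(t \mid \bm{x}) = \int_0^t h(u \mid \bm{x}) \, du,
\qquad
S(t \mid \bm{x}) = \exp\bigl(-H(t \mid \bm{x})\bigr)
```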