
Survival Kernets: Scalable and Interpretable Deep Kernel Survival Analysis with an Accuracy Guarantee

George H. Chen

TL;DR

Survival kernets are a scalable, interpretable deep kernel survival analysis framework that uses kernel netting to compress the training data into clusters for efficient test-time prediction. A warm-start strategy (tuna) based on scalable tree ensembles accelerates training, enabling evaluation on datasets with millions of points while maintaining competitive accuracy. The model yields interpretable cluster-level visualizations (heatmaps and Kaplan-Meier curves). For a special case, it also comes with a finite-sample error bound that links kernel-netted predictions to classical Kaplan-Meier estimates. Empirical results on four diverse survival datasets show strong time-dependent concordance and substantial training-time savings, and the cluster-level visualizations illustrate survival patterns with potential clinical insights.
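Kernel netting compresses the training set into clusters. Below is a minimal sketch of one common way to do this. It assumes the training points have already been mapped into an embedding space; the learned embedding network is omitted. A single greedy pass builds an ε-net, and each point then joins its nearest net center. The names `greedy_epsilon_net` and `assign_clusters` are hypothetical, and the paper's exact construction may differ.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def greedy_epsilon_net(points: List[Point], eps: float) -> List[int]:
    """Return indices of net centers chosen in one greedy pass.

    A point becomes a new center only if it is more than `eps` away from every
    existing center. Afterwards every point lies within `eps` of some center,
    and the centers are pairwise more than `eps` apart.
    """
    centers: List[int] = []
    for i, p in enumerate(points):
        if all(math.dist(p, points[c]) > eps for c in centers):
            centers.append(i)
    return centers


def assign_clusters(points: List[Point], centers: List[int]) -> List[int]:
    """Map each point to the position (within `centers`) of its nearest center."""
    return [
        min(range(len(centers)), key=lambda k: math.dist(p, points[centers[k]]))
        for p in points
    ]


# Toy 2-D "embeddings" forming three well-separated groups (illustrative data).
points = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0), (1.05, 1.0), (3.0, 0.0)]
centers = greedy_epsilon_net(points, eps=0.5)
labels = assign_clusters(points, centers)
print(centers)  # [0, 2, 4]
print(labels)   # [0, 0, 1, 1, 2]
```

A larger `eps` yields fewer, coarser clusters; with `eps=10.0`, every toy point joins the first center. In the theorem's special case, ε is tied to a second user-specified parameter τ via ε = βτ.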

Abstract

Kernel survival analysis models estimate individual survival distributions with the help of a kernel function, which measures the similarity between any two data points. Such a kernel function can be learned using deep kernel survival models. In this paper, we present a new deep kernel survival model called a survival kernet, which scales to large datasets in a manner that is amenable to model interpretation and also theoretical analysis. Specifically, the training data are partitioned into clusters based on a recently developed training set compression scheme for classification and regression called kernel netting that we extend to the survival analysis setting. At test time, each data point is represented as a weighted combination of these clusters, and each such cluster can be visualized. For a special case of survival kernets, we establish a finite-sample error bound on predicted survival distributions that is, up to a log factor, optimal. Whereas scalability at test time is achieved using the aforementioned kernel netting compression strategy, scalability during training is achieved by a warm-start procedure based on tree ensembles such as XGBoost and a heuristic approach to accelerating neural architecture search. On four standard survival analysis datasets of varying sizes (up to roughly 3 million data points), we show that survival kernets are highly competitive compared to various baselines tested in terms of time-dependent concordance index. Our code is available at: https://github.com/georgehc/survival-kernets
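At test time, each point is represented as a weighted combination of clusters. Below is a minimal sketch of one way to turn that representation into a survival curve, under two assumptions:

- Each cluster stores summary functions (event counts and at-risk counts) on a shared time grid.
- The kernel is a hypothetical truncated Gaussian on embedding distances, with cutoff `tau`.

The kernel weights pool the cluster summaries into a single hazard at each time. The survival curve is then the product-limit (Kaplan-Meier-style) product of one minus those hazards. The names `build_summaries`, `kernel_weight`, and `predict_survival` are illustrative and are not taken from the survival-kernets code.

```python
import math
from typing import List, Sequence, Tuple


def build_summaries(
    times: Sequence[float], events: Sequence[int], labels: Sequence[int], n_clusters: int
) -> Tuple[List[float], List[List[int]], List[List[int]]]:
    """Per-cluster event counts d[k][j] and at-risk counts n[k][j] at each unique time."""
    grid = sorted(set(times))
    d = [[0] * len(grid) for _ in range(n_clusters)]
    n = [[0] * len(grid) for _ in range(n_clusters)]
    for t, e, k in zip(times, events, labels):
        for j, g in enumerate(grid):
            if g <= t:
                n[k][j] += 1  # still at risk just before time g
            if g == t and e:
                d[k][j] += 1  # event observed at time g
    return grid, d, n


def kernel_weight(dist: float, tau: float) -> float:
    """Hypothetical kernel: Gaussian in distance, truncated to zero beyond tau."""
    return math.exp(-dist ** 2) if dist <= tau else 0.0


def predict_survival(x, centers, d, n, tau: float) -> List[float]:
    """Kernel-weighted product-limit survival curve on the shared time grid."""
    weights = [kernel_weight(math.dist(x, c), tau) for c in centers]
    surv, s = [], 1.0
    for j in range(len(d[0])):
        num = sum(w * d[k][j] for k, w in enumerate(weights))  # weighted events
        den = sum(w * n[k][j] for k, w in enumerate(weights))  # weighted at-risk
        s *= 1.0 - (num / den if den > 0 else 0.0)
        surv.append(s)
    return surv


# Two clusters with 1-D embedding centers at 0 and 10 (illustrative data).
times = [1, 2, 3, 4, 5, 6]
events = [1, 0, 1, 1, 1, 1]
labels = [0, 0, 0, 0, 1, 1]
grid, d, n = build_summaries(times, events, labels, n_clusters=2)
centers = [(0.0,), (10.0,)]
print(predict_survival((0.1,), centers, d, n, tau=1.0))
# [0.75, 0.75, 0.375, 0.0, 0.0, 0.0]
```

When only one cluster receives nonzero weight, the prediction reduces to that cluster's Kaplan-Meier curve. A point farther than `tau` from every center gets no weight at all, so this sketch returns a flat curve of 1.0 for it; an actual model would need a fallback for that case.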

Paper Structure

This paper contains 73 sections, 10 theorems, 128 equations, 11 figures, and 8 tables.

Key Result

Theorem 5

Suppose that Assumptions $\mathbf{A}^{\text{technical}}$, $\mathbf{A}^{\text{intrinsic}}$, and $\mathbf{A}^{\text{survival}}$ hold, and that we train a survival kernet with $\varepsilon=\beta\tau$ when constructing $\widetilde{\mathcal{Q}}_{\varepsilon}$, where $\beta\in(0,1)$ and $\tau>0$ are user-specified … where $\widetilde{\mathcal{O}}$ ignores log factors.

Figures (11)

  • Figure 1: Visualization of the largest 5 clusters found by a survival kernet model trained on the support dataset (we limit the number of clusters shown for ease of exposition and to keep the plots from becoming too cluttered); more information on how the model is trained is in the experiments section. Panel (a) shows a heatmap visualization that readily shows how the clusters differ, highlighting feature values that are prominent for specific clusters; the dotted horizontal lines separate features that correspond to the same underlying variable. Panel (b) shows Kaplan-Meier survival curves with 95% confidence intervals for the same clusters as in panel (a); the x-axis measures the number of days since a patient entered the study.
  • Figure 2: Survival curves for the largest 5 clusters found by the final tuna-kernet (no split, sft) model trained on the support dataset; the x-axis measures the number of days since a patient entered the study. Note that the green curve has a higher median survival time estimate than the red curve. This is not an error: the clusters are ordered exactly as in Figure 1(b). In particular, the median survival time estimates obtained with summary fine-tuning need not follow the same order as the median survival time estimates from the Kaplan-Meier estimator.
  • Figure 3: Visualization of 10 superclusters for the final tuna-kernet (no split, sft) model trained on the support dataset. These 10 superclusters summarize all 73 clusters found by the tuna-kernet (no split, sft) model and are obtained by merging clusters with complete-linkage agglomerative clustering. Panel (a) shows a feature heatmap visualization. Panel (b) shows survival curves for the same superclusters as in panel (a); the x-axis measures the number of days since a patient entered the study. The second- and third-to-rightmost columns/clusters each contain only one data point, suggesting that they are outliers.
  • Figure 4: Visualization of all 10 clusters found by the final tuna-kernet (no split, sft) model trained on the rotterdam/gbsg dataset (technically trained only on the Rotterdam portion of the data). Panel (a) shows a heatmap visualization that readily shows how the clusters differ, highlighting feature values that are prominent for specific clusters; the dotted horizontal lines separate features that correspond to the same underlying variable. Panel (b) shows survival curves (estimated from learned summary functions) for the same clusters as in panel (a); the x-axis indicates recurrence-free survival time in months.
  • Figure 5: Visualization of the largest 5 clusters found by the final tuna-kernet (no split, sft) model trained on the unos dataset. Panel (a) shows a heatmap visualization that readily provides information on how the clusters are different, highlighting feature values that are prominent for specific clusters; the dotted horizontal lines separate features that correspond to the same underlying variable. Panel (b) shows survival curves (estimated from learned summary functions) for the same clusters as in panel (a); the x-axis measures the number of years since a patient received a heart transplant.
  • ...and 6 more figures
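Figure 1(b) plots per-cluster Kaplan-Meier curves with 95% confidence intervals. Below is a minimal sketch of the Kaplan-Meier estimator with a pointwise interval. It assumes a plain Greenwood standard error and a normal-approximation interval clipped to [0, 1]. The captions do not say which interval construction was used, and other common choices apply a log or log-log transformation first. The function name `kaplan_meier` is illustrative.

```python
import math
from typing import List, Sequence, Tuple


def kaplan_meier(
    times: Sequence[float], events: Sequence[int], z: float = 1.96
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """Kaplan-Meier estimate with Greenwood-based pointwise confidence bands.

    Returns (event_times, survival, lower, upper); z = 1.96 gives ~95% bands.
    """
    event_times = sorted({t for t, e in zip(times, events) if e})
    s, greenwood_sum = 1.0, 0.0
    out_s, lower, upper = [], [], []
    for t in event_times:
        n_risk = sum(1 for ti in times if ti >= t)  # at risk just before t
        d = sum(1 for ti, ei in zip(times, events) if ti == t and ei)  # events at t
        s *= 1.0 - d / n_risk
        if n_risk > d:  # Greenwood term is undefined once the curve hits zero
            greenwood_sum += d / (n_risk * (n_risk - d))
        se = s * math.sqrt(greenwood_sum)
        out_s.append(s)
        lower.append(max(0.0, s - z * se))
        upper.append(min(1.0, s + z * se))
    return event_times, out_s, lower, upper


# Illustrative data: event = 1, censored = 0.
times = [1, 2, 2, 3, 4, 5]
events = [1, 1, 0, 1, 0, 1]
t, s, lo, hi = kaplan_meier(times, events)
print(t)                          # [1, 2, 3, 5]
print([round(v, 3) for v in s])   # [0.833, 0.667, 0.444, 0.0]
```

Censored subjects (times 2 and 4 here) never trigger a drop in the curve. They do shrink the at-risk count for later event times.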

Theorems & Definitions (14)

  • Definition 1
  • Definition 2
  • Claim 3: Follows from Corollary 4.2.13 of Vershynin (2018)
  • Claim 4
  • Theorem 5
  • Proposition 6
  • Lemma 7
  • Lemma 8
  • Lemma 9
  • Lemma 10
  • ...and 4 more
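Claim 3 is stated to follow from Corollary 4.2.13 of Vershynin (2018), which bounds the covering numbers of the Euclidean unit ball. The exact statement of Claim 3 is not reproduced here. As an illustration only, assume the embedded points lie in the unit ball $B_2^d$ of $\mathbb{R}^d$. Under that assumption, the corollary caps how many net centers can exist when the centers are pairwise more than ε apart:

```latex
% Corollary 4.2.13 (Vershynin, 2018): covering numbers of the unit ball B_2^d
\left(\frac{1}{\varepsilon}\right)^{d} \le N(B_2^d, \varepsilon) \le \left(\frac{2}{\varepsilon} + 1\right)^{d}

% Illustrative assumption: centers q_1, ..., q_m lie in B_2^d with \|q_i - q_j\| > \varepsilon for i \ne j.
% Two points in the same closed ball of radius \varepsilon/2 are at most \varepsilon apart,
% so each ball of an (\varepsilon/2)-cover contains at most one center, giving
m \le N(B_2^d, \varepsilon/2) \le \left(\frac{4}{\varepsilon} + 1\right)^{d}
```

Under this assumption, the number of clusters is bounded independently of the training-set size, although the bound grows exponentially in the dimension $d$.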