
Efficient Training of Probabilistic Neural Networks for Survival Analysis

Christian Marius Lillelund, Martin Magris, Christian Fischer Pedersen

TL;DR

This work investigates how to train deep probabilistic survival models on large datasets without additional overhead in model complexity. It adopts three probabilistic approaches, namely VI, MCD, and SNGP, and evaluates them in terms of prediction performance, calibration performance, and model complexity.

Abstract

Variational Inference (VI) is a commonly used technique for approximate Bayesian inference and uncertainty estimation in deep learning models, yet it comes at a computational cost: it doubles the number of trainable parameters to represent uncertainty. This rapidly becomes challenging in high-dimensional settings and motivates alternative inference techniques, such as Monte Carlo Dropout (MCD) or the Spectral-normalized Neural Gaussian Process (SNGP). However, such methods have seen little adoption in survival analysis, where VI remains the prevalent approach for training probabilistic neural networks. In this paper, we investigate how to train deep probabilistic survival models on large datasets without introducing additional overhead in model complexity. To this end, we adopt three probabilistic approaches, namely VI, MCD, and SNGP, and evaluate them in terms of prediction performance, calibration performance, and model complexity. In the context of probabilistic survival analysis, we investigate whether non-VI techniques can offer comparable or improved prediction performance and uncertainty calibration relative to VI. On the MIMIC-IV dataset, we find that MCD performs on par with VI in concordance index (0.748 vs. 0.743) and mean absolute error (254.9 vs. 254.7) using the hinge loss, while providing C-calibrated uncertainty estimates. Moreover, our SNGP implementation yields D-calibrated survival functions on all four datasets, compared with two of four for VI. Our work encourages the use of alternatives to VI for survival analysis on high-dimensional datasets, where computational efficiency and overhead are a concern.
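The parameter-doubling claim can be made concrete with a short count. The sketch below assumes mean-field Gaussian VI, in which every weight and bias of every layer gets a learned mean and a learned scale; the layer widths (4 covariates as in Figure 1, three hidden layers of 32 units, one output) and the function names are hypothetical, chosen only for illustration. Dropout itself adds no trainable parameters, so under this assumption an MCD network has the same count as a plain MLP.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


def mean_field_vi_param_count(layer_sizes):
    """Mean-field Gaussian VI (assumed): one mean and one scale per weight and bias."""
    return 2 * mlp_param_count(layer_sizes)


# Hypothetical architecture: 4 covariates, three hidden layers of 32 units, 1 output.
sizes = [4, 32, 32, 32, 1]
print(mlp_param_count(sizes))            # 2305 (plain MLP, also the MCD count)
print(mean_field_vi_param_count(sizes))  # 4610 (mean-field VI)
```

The gap grows linearly with network size, which is why the overhead matters most for wide networks on high-dimensional inputs.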

Paper Structure (16 sections, 13 equations, 2 figures, 5 tables)

Figures (2)

  • Figure 1: The proposed architectures for computing $\hat{y}_{i}$, i.e., the predicted survival time for individual $i$. (a): BNN with three hidden layers trained with VI, where $\bm{x}_i \in \mathbb{R}^4$ is a four-dimensional covariate vector. (b): MLP with three hidden layers trained with MCD. Transparent nodes illustrate the random activation of units under MCD. (c): MLP using the SNGP technique, with spectrally normalized hidden layers and the output computed from a GP layer.
  • Figure 2: Model predictions on the METABRIC dataset using MCD and VI. For an individual $i$, we use the median of the predicted survival function as the predicted survival time $\hat{y}_i$. Although MCD and VI predict similar median survival times, VI shows more variance around this particular prediction than MCD. Left column: mean individual survival function and its corresponding 90% CrI. Right column: histogram of the predicted survival times, approximating the underlying density.
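The summary described in the Figure 2 caption can be sketched as below. It assumes each stochastic forward pass (an MC dropout mask, or a weight draw from the VI posterior) has already produced a survival curve on a shared time grid. The exponential curves are a hypothetical stand-in for a trained network, the function names are illustrative, and the 90% band is taken as pointwise 5th and 95th percentiles, which may differ from how the authors compute their CrI. The caption does not say whether $\hat{y}_i$ comes from the mean curve or from each sampled curve, so the sketch computes both.

```python
import numpy as np


def median_survival_time(times, surv):
    """Return the first grid time at which S(t) <= 0.5.

    Works on one curve (n_times,) or a stack of curves (T, n_times).
    If a curve never reaches 0.5, the last grid time is returned
    (a fallback assumed by this sketch).
    """
    below = surv <= 0.5
    idx = np.where(below.any(axis=-1), below.argmax(axis=-1), surv.shape[-1] - 1)
    return times[idx]


def summarize_mc_samples(times, surv_samples, level=0.90):
    """Summarize T sampled survival curves of shape (T, n_times).

    Returns the mean curve, a pointwise equal-tailed band at `level`,
    and one predicted (median) survival time per sampled curve.
    """
    alpha = (1.0 - level) / 2.0
    mean_curve = surv_samples.mean(axis=0)
    lower = np.quantile(surv_samples, alpha, axis=0)
    upper = np.quantile(surv_samples, 1.0 - alpha, axis=0)
    medians = median_survival_time(times, surv_samples)
    return mean_curve, lower, upper, medians


# Hypothetical stand-in for T = 100 stochastic forward passes:
# exponential survival curves with randomly perturbed hazard rates.
rng = np.random.default_rng(0)
times = np.linspace(0.0, 500.0, 501)
rates = rng.lognormal(mean=np.log(0.01), sigma=0.2, size=100)
surv_samples = np.exp(-np.outer(rates, times))

mean_curve, lower, upper, medians = summarize_mc_samples(times, surv_samples)
y_hat = median_survival_time(times, mean_curve)  # point prediction from the mean curve
counts, edges = np.histogram(medians, bins=20)   # density approximation, as in the right column
```

Running the same summary on MCD and VI samples for one individual would allow the spread of `medians`, and the width between `lower` and `upper`, to be compared in the way Figure 2 does.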