Mini-batch Estimation for Deep Cox Models: Statistical Foundations and Practical Guidance

Lang Zeng, Weijing Tang, Zhao Ren, Ying Ding

TL;DR

This work develops a theoretical and practical framework for mini-batch estimation in deep Cox models by introducing the mini-batch maximum partial-likelihood estimator (mb-MPLE). It proves that mb-MPLE is consistent and attains the minimax-optimal convergence rate (up to a polylogarithmic factor) for Cox-NN, with rates governed by intrinsic function complexity rather than ambient dimension. For Cox regression, it shows that mb-MPLE is $\sqrt{n}$-consistent and asymptotically normal, with an asymptotic variance that depends on the batch size and approaches the information lower bound as the batch size increases. The paper also provides actionable SGD guidance, showing that the ratio of learning rate to batch size critically shapes training dynamics and that larger batches improve local convexity and efficiency. These findings are validated through simulations and an analysis of the large AREDS dataset with a ResNet-50 backbone.

Abstract

The stochastic gradient descent (SGD) algorithm has been widely used to optimize the deep Cox neural network (Cox-NN) by updating model parameters using mini-batches of data. We show that SGD aims to optimize the average of the mini-batch partial-likelihood, which is different from the standard partial-likelihood. This distinction requires developing new statistical properties for the global optimizer, namely, the mini-batch maximum partial-likelihood estimator (mb-MPLE). We establish that mb-MPLE for Cox-NN is consistent and achieves the optimal minimax convergence rate up to a polylogarithmic factor. For Cox regression with linear covariate effects, we further show that mb-MPLE is $\sqrt{n}$-consistent and asymptotically normal with asymptotic variance approaching the information lower bound as the batch size increases, which is confirmed by simulation studies. Additionally, we offer practical guidance on using SGD, supported by theoretical analysis and numerical evidence. For Cox-NN, we demonstrate that the ratio of the learning rate to the batch size is critical in SGD dynamics, offering insight into hyperparameter tuning. For Cox regression, we characterize the iterative convergence of SGD, ensuring that the global optimizer, mb-MPLE, can be approximated with sufficiently many iterations. Finally, we demonstrate the effectiveness of mb-MPLE in a large-scale real-world application where the standard MPLE is intractable.
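
To make the distinction between the two objectives concrete, here is a minimal NumPy sketch of the quantity that SGD targets: the negative log partial likelihood is computed within each mini-batch, with risk sets restricted to that batch, and then averaged across batches. The function names, the 1/s normalization, the dropping of leftover samples, and the absence of tie handling are assumptions made for illustration; they are not taken from the paper.

```python
import numpy as np

def neg_log_partial_likelihood(risk, time, event):
    """Negative log partial likelihood on one batch, divided by the batch size.

    risk  : predicted log-risk scores f(X_i) (hypothetical model output)
    time  : observed times
    event : 1 if the event was observed, 0 if censored
    Ties in `time` get no special correction (an assumption of this sketch).
    """
    risk = np.asarray(risk, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=float)
    # at_risk[i, j] is True when subject j is still at risk at subject i's time.
    at_risk = time[None, :] >= time[:, None]
    # Stable log of sum_j exp(risk_j) over each risk set.
    shift = risk.max()
    log_denom = shift + np.log((at_risk * np.exp(risk - shift)[None, :]).sum(axis=1))
    return float(-np.sum(event * (risk - log_denom)) / len(risk))

def minibatch_objective(risk, time, event, batch_size, rng):
    """Average of per-batch losses over one random partition of the data."""
    risk, time, event = map(np.asarray, (risk, time, event))
    n = len(risk)
    perm = rng.permutation(n)
    usable = n - n % batch_size  # drop leftovers so every batch has equal size
    batches = perm[:usable].reshape(-1, batch_size)
    losses = [neg_log_partial_likelihood(risk[b], time[b], event[b]) for b in batches]
    return float(np.mean(losses))
```

When `batch_size` equals the sample size, the average reduces to the standard full-data loss. For smaller batches each risk set contains only members of the same batch, which is why the two objectives differ.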

Paper Structure

This paper contains 16 sections, 7 theorems, 31 equations, and 5 figures.

Key Result

Lemma 1

Let $L^{(s)}_{0}(f) := \mathbb{E}[L^{(s)}_{Cox}(f)] = \mathbb{E}_{D(n)}\left[\mathbb{E}\left[L^{(s)}_{Cox}(f) \mid D(n)\right]\right]$. Under the Cox model, with assumptions (A1)-(A3) and (N1), for any integer $s \geq 2$ and constant $c>0$, we have for all $f\in \{f:\lVert f\rVert_\infty\leq c, \mathbb{E}[f(X)]=0\}$, where $d(f,f_0)=[\mathbb{E}\{f(X)-f_0(X)\}^2]^{\frac{1}{2}}$.
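
The distance $d(f,f_0)$ used in Lemma 1 is an $L_2$ distance under the covariate distribution, so it can be approximated by a sample average. The following is a minimal sketch, assuming a uniform covariate on $[-1, 1]$ and two hypothetical linear risk functions; neither choice comes from the paper. Both functions have mean zero under this distribution, which matches the centering constraint $\mathbb{E}[f(X)]=0$.

```python
import numpy as np

def l2_distance(f, f0, x):
    """Monte Carlo estimate of d(f, f0) = sqrt(E[(f(X) - f0(X))^2])."""
    diff = f(x) - f0(x)
    return float(np.sqrt(np.mean(diff ** 2)))

# Hypothetical setup: X ~ Uniform(-1, 1), f0(x) = x, f(x) = 1.5 * x.
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=100_000)
f0 = lambda x: x
f = lambda x: 1.5 * x

# The exact value here is 0.5 * sqrt(E[X^2]) = 0.5 / sqrt(3).
print(l2_distance(f, f0, x), 0.5 / np.sqrt(3))
```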

Figures (5)

  • Figure 1: (a) An illustration of the properties of ${\mathbb{E}} [L_{Cox}^{(s)}(\theta)]$ in Cox regression: ${\mathbb{E}} [L_{Cox}^{(s)}(\theta)]$ reaches its minimum at $\theta_0$ regardless of the choice of $s$, while its local convexity at $\theta_0$ increases when $s$ doubles. (b) Estimated ${\mathbb{E}} [\nabla_\theta L_{Cox}^{(s)}(\theta)]$ in a neighborhood of $\theta_0=1$ for different batch sizes $s$. Each estimate is based on 20,000 realizations of the mini-batch data, each consisting of $s$ i.i.d. samples generated from a Cox model with $f_0(X) = \theta_0X$.
  • Figure 2: Boxplots of $\log(\lVert \hat{\theta}-\theta_0\rVert_2^2)$ over $200$ runs with sample size $n=2,048$ where $\hat{\theta}$ is solved by four different methods. SGD with two different batch sampling strategies is considered, either with a fixed batch sampling strategy (FB) or with a stochastic batch sampling strategy (SB). CoxPH-strata is a stratified Cox model treating the fixed batches from SGD-FB as strata and serves as the global minimizer for SGD-FB.
  • Figure 3: The negative log-partial likelihood $L_{Cox}^{(N_{test})}(\theta)$ evaluated on test data ($N_{test}= 2,048$) over the training epochs. The learning rate $\gamma$ is 0.1/16 when the batch size is 32, and $\gamma$ is doubled whenever the batch size is doubled. All other hyperparameters are kept the same.
  • Figure 4: First panel: the structure of the Cox-NN model optimized by SGD to predict the time-to-AMD progression based on the fundus image and demographics; second panel: the required memory and running time of SGD over different batch sizes; third panel: the C-index (on the test data) over training epochs under the choice of different batch sizes.
  • Figure 5: The C-index on test data of Cox-NN over training epochs in the AREDS application. The Cox-NN is optimized by SGD with different choices of batch size and learning rate. All the other hyperparameters are fixed.
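
Two quantities in the captions above lend themselves to a short sketch: the learning-rate schedule of Figure 3, which keeps the ratio of learning rate to batch size fixed, and the C-index reported in Figures 4 and 5. The code below is an illustration under stated assumptions. The function names are hypothetical, and the C-index follows Harrell's common definition (risk ties counted as one half), which may differ from the evaluation code the authors used.

```python
import numpy as np

def scaled_learning_rate(batch_size, base_lr=0.1 / 16, base_batch=32):
    """Linear scaling: doubling the batch size doubles the learning rate.

    The defaults mirror the Figure 3 caption (0.1/16 at batch size 32).
    """
    return base_lr * batch_size / base_batch

def concordance_index(risk, time, event):
    """Harrell's C-index for right-censored data (assumed definition).

    A pair (i, j) is comparable when subject i has an observed event strictly
    before subject j's time. It is concordant when subject i also has the
    higher predicted risk. Requires at least one comparable pair.
    """
    risk, time, event = map(np.asarray, (risk, time, event))
    comparable = (time[:, None] < time[None, :]) & (event[:, None] == 1)
    concordant = (risk[:, None] > risk[None, :]) & comparable
    tied = (risk[:, None] == risk[None, :]) & comparable
    return float((concordant.sum() + 0.5 * tied.sum()) / comparable.sum())

# The ratio of learning rate to batch size stays constant across batch sizes.
for s in (32, 64, 128):
    print(s, scaled_learning_rate(s), scaled_learning_rate(s) / s)
```

A perfectly ordered risk score gives a C-index of 1, a reversed one gives 0, and a constant score gives 0.5.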

Theorems & Definitions (20)

  • Lemma 1
  • Remark 1
  • Theorem 1
  • Remark 2
  • Remark 3
  • Remark 4
  • Theorem 2
  • Remark 5
  • Proposition 1
  • Theorem 3
  • ...and 10 more