
Gradient descent inference in empirical risk minimization

Qiyang Han, Xiaocong Xu

TL;DR

This work develops a non-asymptotic, joint distributional theory for gradient descent in empirical risk minimization within the mean-field regime, where the sample size $m$ scales proportionally with the signal dimension $n$. It introduces a debiased gradient descent framework that uses two Onsager correction matrices and a state evolution characterization to show that gradient descent iterates become approximately normal after a debiasing step. The framework covers broad classes of convex and non-convex losses and non-Gaussian data. A gradient descent inference algorithm provides data-driven estimates of the Onsager matrices, enabling iteration-wise inference and generalization-error estimation without requiring convergence to the ERM solution. The theory is instantiated in single-index regression and generalized logistic regression, and extensive numerical experiments show accurate confidence intervals, robust generalization-error estimates, and practical computational advantages over alternatives such as leave-one-out cross-validation (LOOCV). Overall, the paper offers a practical, iteration-aware inference paradigm that couples gradient descent with principled statistical debiasing and error estimation in high-dimensional settings.

Abstract

Gradient descent is one of the most widely used iterative algorithms in modern statistical learning. However, its precise algorithmic dynamics in high-dimensional settings remain only partially understood, which has limited its use for statistical inference. This paper provides a precise, non-asymptotic joint distributional characterization of gradient descent iterates and their debiased statistics in a broad class of empirical risk minimization problems, in the so-called mean-field regime where the sample size is proportional to the signal dimension. Our non-asymptotic state evolution theory holds for both general non-convex loss functions and non-Gaussian data, and reveals the central role of two Onsager correction matrices that precisely characterize the non-trivial dependence among all gradient descent iterates in the mean-field regime. Leveraging the joint state evolution characterization, we show that the gradient descent iterate recovers approximate normality after a debiasing correction via a linear combination of all past iterates, where the debiasing coefficients can be estimated by the proposed gradient descent inference algorithm. This leads to a new algorithmic statistical inference framework based on debiased gradient descent, which (i) applies to a broad class of models with both convex and non-convex losses, (ii) remains valid at each iteration without requiring algorithmic convergence, and (iii) exhibits a certain robustness to possible model misspecification. As a by-product, our framework also provides algorithmic estimates of the generalization error at each iteration. As canonical examples, we demonstrate our theory and inference methods in the single-index regression model and a generalized logistic regression model, where the natural loss functions may exhibit arbitrarily non-convex landscapes.
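The gradient descent iterates whose joint distribution the state evolution theory describes can be made concrete with a short sketch. The Python below is a minimal illustration, not the authors' code: it assumes $m$ samples stored as the rows of a design matrix `A` with i.i.d. $N(0, 1/n)$ entries, an empirical risk $\sum_i L(\langle A_i, \mu\rangle, Y_i)$ with a user-supplied derivative of $L$, a constant step size `eta`, and zero initialization. It keeps every iterate because the debiasing correction combines all past iterates. All function names are hypothetical, and the Onsager-matrix debiasing itself is not implemented here.

```python
import numpy as np


def gradient_descent_iterates(A, y, loss_grad, eta, T, mu0=None):
    """Full-batch gradient descent on F(mu) = sum_i L(<A_i, mu>, y_i).

    loss_grad(z, y) is the derivative of L in its first argument, applied
    elementwise to the fitted values z = A @ mu. Returns an array of shape
    (T + 1, n) whose row t is the iterate mu^(t), with mu^(0) = mu0 (or 0).
    """
    n = A.shape[1]
    mu = np.zeros(n) if mu0 is None else np.array(mu0, dtype=float)
    iterates = [mu.copy()]
    for _ in range(T):
        # grad F(mu) = A^T L'(A mu, y); take one step of size eta.
        mu = mu - eta * (A.T @ loss_grad(A @ mu, y))
        iterates.append(mu.copy())
    return np.stack(iterates)


def squared_loss_grad(z, y):
    # L(z, y) = (z - y)^2 / 2, so dL/dz = z - y.
    return z - y


def pseudo_huber_grad(z, y, delta=1.0):
    # L(z, y) = delta^2 * (sqrt(1 + ((z - y) / delta)^2) - 1), so
    # dL/dz = r / sqrt(1 + (r / delta)^2) with r = z - y; |dL/dz| < delta.
    r = z - y
    return r / np.sqrt(1.0 + (r / delta) ** 2)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m, n = 800, 400                            # proportional regime, m / n = 2
    A = rng.normal(size=(m, n)) / np.sqrt(n)   # assumed N(0, 1/n) design
    mu_star = rng.normal(size=n)
    y = A @ mu_star + 0.5 * rng.normal(size=m)
    for name, grad in [("squared", squared_loss_grad), ("pseudo-Huber", pseudo_huber_grad)]:
        path = gradient_descent_iterates(A, y, grad, eta=0.1, T=50)
        err = np.linalg.norm(path - mu_star, axis=1) ** 2 / n
        print(name, path.shape, round(err[0], 3), round(err[-1], 3))
```

The step size 0.1 is an assumption chosen to sit well below $2/\lambda_{\max}(A^\top A)$ for this design, so each step decreases the empirical risk; the paper's own step-size conditions may differ.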


Paper Structure

This paper contains 66 sections, 30 theorems, 275 equations, 8 figures, 1 table, 1 algorithm.

Key Result

Theorem 2.2

Suppose the standing setup assumption holds for some $K, \Lambda \geq 2$. In both error estimates, the index $s$ in the brackets runs over $s \in [1:t]$.

Figures (8)

  • Figure 1: Linear regression. Top row: Squared loss. Bottom row: Pseudo-Huber loss.
  • Figure 2: Single-index regression model with squared loss. Top row: sigmoid link $\varphi_\ast(x) = 1/(1 + e^{-x})$. Bottom row: nonlinear link $\varphi_\ast(x) = x + \sin(x)$.
  • Figure 3: Logistic regression. Top row: Squared loss. Bottom row: Logistic loss.
  • Figure 4: Linear regression with $\ell_1$ penalty. Top row: Squared loss. Bottom row: Pseudo-Huber loss.
  • Figure 5: Single-index regression with $\ell_1$ penalty and squared loss. Top row: sigmoid link $\varphi_\ast(x) = 1/(1 + e^{-x})$. Bottom row: nonlinear link $\varphi_\ast(x) = x + \sin(x)$.
  • ...and 3 more figures
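To make the single-index setting of Figure 2 concrete, here is a hedged Python sketch that simulates data from the additive-noise form $Y_i = \varphi_\ast(\langle A_i, \mu_\ast\rangle) + \xi_i$ (assumed here) with the two link functions named in the caption, then runs squared-loss gradient descent on it. The Gaussian design with $N(0, 1/n)$ entries, the Gaussian signal and noise, the sample sizes, the step size and all function names are illustrative assumptions, not the settings behind the paper's figures. The sketch reports only a cosine similarity to $\mu_\ast$ and does not reproduce the debiased confidence intervals.

```python
import numpy as np


def sigmoid_link(x):
    # phi_*(x) = 1 / (1 + e^{-x}), top row of Figure 2.
    return 1.0 / (1.0 + np.exp(-x))


def sine_link(x):
    # phi_*(x) = x + sin(x), bottom row of Figure 2.
    return x + np.sin(x)


def simulate_single_index(link, m, n, noise_sd=0.5, seed=0):
    """Draw y_i = link(<A_i, mu_star>) + noise (hypothetical helper).

    Assumed design: A has i.i.d. N(0, 1/n) entries, mu_star ~ N(0, I_n),
    and the noise is N(0, noise_sd^2).
    """
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(m, n)) / np.sqrt(n)
    mu_star = rng.normal(size=n)
    y = link(A @ mu_star) + noise_sd * rng.normal(size=m)
    return A, y, mu_star


def gd_squared_loss(A, y, eta, T):
    """Gradient descent on sum_i (<A_i, mu> - y_i)^2 / 2 from mu^(0) = 0."""
    mu = np.zeros(A.shape[1])
    path = [mu.copy()]
    for _ in range(T):
        mu = mu - eta * (A.T @ (A @ mu - y))
        path.append(mu.copy())
    return np.stack(path)


def cosine_to_truth(mu, mu_star):
    # The squared loss is misspecified for a nonlinear link, so mu^(t) is
    # compared with mu_star only in direction, not in scale.
    return float(mu @ mu_star / (np.linalg.norm(mu) * np.linalg.norm(mu_star)))


if __name__ == "__main__":
    for name, link in [("sigmoid", sigmoid_link), ("x + sin(x)", sine_link)]:
        A, y, mu_star = simulate_single_index(link, m=1000, n=500)
        path = gd_squared_loss(A, y, eta=0.1, T=30)
        print(name, round(cosine_to_truth(path[-1], mu_star), 3))
```

Comparing only directions is a deliberate choice: under a nonlinear link, squared-loss gradient descent tracks a rescaled version of $\mu_\ast$, which is one instance of the model misspecification the abstract refers to.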

Theorems & Definitions (62)

  • Example 1.1: Single-index regression model
  • Example 1.2: Generalized logistic regression
  • Definition 2.1
  • Remark 1
  • Theorem 2.2
  • Theorem 2.3
  • Definition 2.4
  • Theorem 2.5
  • Remark 2
  • Remark 3
  • ...and 52 more