Gradient descent inference in empirical risk minimization
Qiyang Han, Xiaocong Xu
TL;DR
This work develops a non-asymptotic, joint distributional theory for gradient descent in empirical risk minimization within the mean-field regime, where the sample size $m$ is proportional to the signal dimension $n$. It introduces a debiased gradient-descent framework that uses Onsager correction matrices and state evolution to show that gradient iterates become approximately normal after a debiasing step; the framework applies to a broad class of convex and non-convex losses and to non-Gaussian data. A gradient-descent inference algorithm provides data-driven estimates of the Onsager matrices, enabling iteration-wise inference and generalization-error estimation without requiring convergence to the ERM. The theory is instantiated in single-index regression and generalized logistic regression, with extensive numerical experiments validating accurate confidence intervals, robust generalization estimates, and practical computational advantages over alternatives like LOOCV. Overall, the paper offers a practical, iteration-aware inference paradigm that couples algorithmic gradient descent with principled statistical debiasing and error estimation in high-dimensional settings.
Abstract
Gradient descent is one of the most widely used iterative algorithms in modern statistical learning. However, its precise algorithmic dynamics in high-dimensional settings remain only partially understood, which has limited its broader potential for statistical inference applications. This paper provides a precise, non-asymptotic joint distributional characterization of gradient descent iterates and their debiased statistics in a broad class of empirical risk minimization problems, in the so-called mean-field regime where the sample size is proportional to the signal dimension. Our non-asymptotic state evolution theory holds for both general non-convex loss functions and non-Gaussian data, and reveals the central role of two Onsager correction matrices that precisely characterize the non-trivial dependence among all gradient descent iterates in the mean-field regime. Leveraging the joint state evolution characterization, we show that the gradient descent iterate retrieves approximate normality after a debiasing correction via a linear combination of all past iterates, where the debiasing coefficients can be estimated by the proposed gradient descent inference algorithm. This leads to a new algorithmic statistical inference framework based on debiased gradient descent, which (i) applies to a broad class of models with both convex and non-convex losses, (ii) remains valid at each iteration without requiring algorithmic convergence, and (iii) exhibits a certain robustness to possible model misspecification. As a by-product, our framework also provides algorithmic estimates of the generalization error at each iteration. As canonical examples, we demonstrate our theory and inference methods in the single-index regression model and a generalized logistic regression model, where the natural loss functions may exhibit arbitrarily non-convex landscapes.
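To make the setting concrete, the following is a minimal sketch of plain gradient descent on a logistic empirical risk in a proportional regime, keeping every iterate rather than only the last one. It is an illustration under stated assumptions, not the paper's inference algorithm: the Gaussian design, the ratio $m = 2n$, the step size, the zero initialization, and all function names are hypothetical choices made here, and the Onsager-matrix estimation and debiasing step are not reproduced.

```python
import numpy as np

def avg_logistic_loss(mu, A, y):
    """Average logistic loss for labels y in {0, 1}."""
    z = A @ mu
    return np.mean(np.logaddexp(0.0, z) - y * z)

def logistic_loss_grad(mu, A, y):
    """Gradient of the average logistic loss with respect to mu."""
    p = 1.0 / (1.0 + np.exp(-(A @ mu)))
    return A.T @ (p - y) / A.shape[0]

def run_gd(A, y, step=0.5, n_iter=50):
    """Run gradient descent from zero and return all iterates, shape (n_iter + 1, n)."""
    mu = np.zeros(A.shape[1])
    iterates = [mu.copy()]
    for _ in range(n_iter):
        mu = mu - step * logistic_loss_grad(mu, A, y)
        iterates.append(mu.copy())
    return np.stack(iterates)

# Assumed simulation setup: sample size m proportional to dimension n.
rng = np.random.default_rng(0)
n = 200
m = 2 * n
mu_star = rng.normal(size=n) / np.sqrt(n)   # hypothetical true signal
A = rng.normal(size=(m, n))                 # Gaussian design (an assumption)
y = (rng.random(m) < 1.0 / (1.0 + np.exp(-(A @ mu_star)))).astype(float)

iterates = run_gd(A, y)

# Training loss at every iteration.
train_loss = np.array([avg_logistic_loss(mu, A, y) for mu in iterates])

# Reference generalization error per iteration, computed on fresh data.
# This oracle quantity is only a simulation baseline; it needs new samples,
# which a data-driven estimate is meant to avoid.
A_new = rng.normal(size=(5000, n))
y_new = (rng.random(5000) < 1.0 / (1.0 + np.exp(-(A_new @ mu_star)))).astype(float)
test_loss = np.array([avg_logistic_loss(mu, A_new, y_new) for mu in iterates])
```

Storing the full path of iterates reflects the abstract's point that inference is meant to be valid at each iteration and that the debiasing correction combines all past iterates; a routine that returned only the final iterate would discard exactly the information such a correction would need.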
