
HiGrad: Uncertainty Quantification for Online Learning and Stochastic Approximation

Weijie J. Su, Yuancheng Zhu

TL;DR

HiGrad develops a statistically principled framework for uncertainty quantification in online learning by augmenting SGD with a hierarchical tree that yields multiple correlated estimators. It constructs a $t$-based confidence interval for $\mu_x(\theta^*)$ using a decorrelated, Ruppert–Polyak-inspired covariance structure and proves asymptotically exact coverage under standard convexity and regularity conditions. The method achieves the same asymptotic variance as vanilla averaged SGD and, with optimally chosen weights, can attain the Cramér–Rao lower bound when the model is correctly specified. Extensions cover flexible tree structures, mini-batch updates, multivariate targets, and practical tricks such as burn-in and restarting. Extensive simulations and a real-data example demonstrate robust finite-sample performance, and the method is implemented in the publicly available R package higrad.

Abstract

Stochastic gradient descent (SGD) is an immensely popular approach for online learning in settings where data arrives in a stream or data sizes are very large. However, despite an ever-increasing volume of work on SGD, much less is known about the statistical inferential properties of SGD-based predictions. Taking a fully inferential viewpoint, this paper introduces a novel procedure termed HiGrad to conduct statistical inference for online learning, without incurring additional computational cost compared with SGD. The HiGrad procedure begins by performing SGD updates for a while and then splits the single thread into several threads, and this procedure hierarchically operates in this fashion along each thread. With predictions provided by multiple threads in place, a $t$-based confidence interval is constructed by decorrelating predictions using covariance structures given by a Donsker-style extension of the Ruppert–Polyak averaging scheme, which is a technical contribution of independent interest. Under certain regularity conditions, the HiGrad confidence interval is shown to attain asymptotically exact coverage probability. Finally, the performance of HiGrad is evaluated through extensive simulation studies and a real data example. An R package higrad has been developed to implement the method.
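The abstract builds its interval from $T$ correlated thread predictions. One way to "decorrelate and then apply a $t$ interval" can be sketched as generalized least squares, under the assumption that the thread predictions are jointly Gaussian with common mean $\mu$ and covariance $\sigma^2 \Sigma$, where $\Sigma$ is known and $\sigma^2$ is unknown. The matrix `sigma` is a user-supplied input here; how HiGrad derives it from the Ruppert–Polyak structure is not reproduced. All function names are hypothetical and are not the higrad package's API. The code uses only the standard library, so the $t$ quantile is computed numerically.

```python
import math

def solve(a, b):
    """Solve a @ x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(n):
            if r != col:
                f = m[r][col] / m[col][col]
                for c in range(col, n + 1):
                    m[r][c] -= f * m[col][c]
    return [m[i][n] / m[i][i] for i in range(n)]

def t_cdf(x, df):
    """Student's t CDF via Simpson's rule on the density over [0, |x|]."""
    logc = math.lgamma((df + 1) / 2) - math.lgamma(df / 2) - 0.5 * math.log(df * math.pi)
    dens = lambda u: math.exp(logc - (df + 1) / 2 * math.log1p(u * u / df))
    n = 2000  # even number of Simpson panels
    h = abs(x) / n
    s = dens(0.0) + dens(abs(x))
    s += sum((4 if i % 2 else 2) * dens(i * h) for i in range(1, n))
    area = s * h / 3
    return 0.5 + area if x >= 0 else 0.5 - area

def t_quantile(p, df):
    """Inverse t CDF by bisection, for p in (0.5, 1)."""
    lo, hi = 0.0, 1e3
    for _ in range(60):
        mid = (lo + hi) / 2
        if t_cdf(mid, df) < p:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2

def decorrelated_t_interval(mu, sigma, level=0.95):
    """GLS estimate of the common mean plus a t interval with T - 1 df.

    Assumes mu ~ N(theta * 1, s2 * sigma) with sigma known and s2 unknown.
    """
    T = len(mu)
    w = solve(sigma, [1.0] * T)  # sigma^{-1} 1
    denom = sum(w)               # 1' sigma^{-1} 1
    est = sum(wi * mi for wi, mi in zip(w, mu)) / denom
    r = [m - est for m in mu]    # residuals
    q = sum(ri * zi for ri, zi in zip(r, solve(sigma, r)))  # r' sigma^{-1} r
    se = math.sqrt(q / (T - 1) / denom)  # estimated standard error of est
    half = t_quantile(0.5 + level / 2, T - 1) * se
    return est, est - half, est + half

# Hypothetical example: two threads that share one of two equal-variance segments.
estimate, lower, upper = decorrelated_t_interval([1.0, 1.2], [[2.0, 1.0], [1.0, 2.0]])
```

With $\Sigma = I$ this reduces to the ordinary one-sample $t$ interval. When every row of $\Sigma$ has the same sum, as in a balanced tree, $\Sigma^{-1}\mathbf{1}$ is proportional to $\mathbf{1}$ and the GLS estimate is simply the average of the thread predictions.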

Paper Structure

This paper contains 38 sections, 19 theorems, 201 equations, 8 figures, 2 tables, 1 algorithm.

Key Result

Proposition 1

Let $K$ and $B_1, \ldots, B_K$ be fixed. For each $k$, assume $n_k/N$ converges to a nonzero constant as $N \rightarrow \infty$. Under Assumptions ass:cvx (convexity) and ass:reg (regularity), taking step sizes $\gamma_j = \frac{c_1}{(j + c_2)^{\alpha}}$ for fixed $\alpha \in (0.5, 1), c_1 > 0$ and $c_2$ ensures the followi
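Proposition 1's step-size schedule $\gamma_j = c_1/(j + c_2)^{\alpha}$ can be illustrated in a single thread of SGD with Ruppert–Polyak averaging. This is a minimal sketch of vanilla averaged SGD rather than HiGrad. It assumes a toy problem: estimating the mean of a stream by minimizing the squared loss $(\theta - z)^2/2$. The constants $c_1 = 0.5$, $c_2 = 0$, $\alpha = 0.6$ and the function names are hypothetical choices.

```python
import random

def step_size(j, c1=0.5, c2=0.0, alpha=0.6):
    """Step size gamma_j = c1 / (j + c2)**alpha from Proposition 1.

    Proposition 1 requires alpha in (0.5, 1) and c1 > 0; the defaults are
    arbitrary illustrative values, and j + c2 must stay positive.
    """
    return c1 / (j + c2) ** alpha

def averaged_sgd(stream, theta0=0.0, **step_kw):
    """Single-thread SGD with Ruppert-Polyak averaging on a toy problem.

    Toy assumption: estimate the mean of the stream by minimizing the squared
    loss (theta - z)**2 / 2, whose gradient in theta is (theta - z).
    Returns the last iterate and the running average of all iterates.
    """
    theta, avg = theta0, 0.0
    for j, z in enumerate(stream, start=1):
        theta -= step_size(j, **step_kw) * (theta - z)  # SGD update
        avg += (theta - avg) / j  # running mean of theta_1, ..., theta_j
    return theta, avg

random.seed(0)
data = [random.gauss(3.0, 1.0) for _ in range(20000)]  # simulated stream
last, averaged = averaged_sgd(data)
```

The averaged iterate is the quantity whose fluctuations the Ruppert–Polyak scheme describes. HiGrad computes such averages segment by segment along each thread of its tree.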

Figures (8)

  • Figure 1: Length of 90% empirical prediction intervals versus average predicted probabilities on a test set of size 1,000 from the Adult dataset, calculated based on 500 independent SGD runs, each with 25 epochs.
  • Figure 2: Graphical illustration of the HiGrad tree. Here we have three levels. At the end of the first level, the segment is split into two; at the end of the second level, each segment is further split into three. There are six threads in this HiGrad, each defined as a path from the root node to one of the six leaf nodes.
  • Figure 3: Graphical illustration of the HiGrad algorithm. Here we have three levels and at the end of each level, each segment is split into two segments. Averages are obtained for each level and at each leaf a weighted average is calculated. The weights $w_j$ are detailed in Section sec:decorr-thre, and more discussion about the tree structure is given in Section sec:how-split.
  • Figure 4: Rescaled expected length of confidence intervals versus $T$, the number of HiGrad threads. The left and right plots correspond to significance levels $\alpha = 0.05$ and $\alpha = 0.1$, respectively. The gray dashed lines indicate the confidence interval lengths at $T = \infty$.
  • Figure 5: Estimation accuracy of HiGrad against the total number of iteration steps. The risk is averaged over 100 replicates and is further normalized by that of vanilla SGD. The four HiGrad configurations are described in Table table:configs.
  • ...and 3 more figures
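The tree in Figures 2 and 3 can be sketched as follows. Threads are represented as root-to-leaf paths, and a leaf estimate is a weighted average of the segment averages along its path. The split list `[2, 3]` reproduces Figure 2, which has six threads. The weights $w_j$ are specified in the paper and are not reproduced, so the `weights` argument and all function names here are hypothetical.

```python
from itertools import product

def higrad_threads(splits):
    """Enumerate threads as root-to-leaf paths.

    splits[k] is the number of children each segment gets at the end of
    level k + 1, so the tree has len(splits) + 1 levels and prod(splits)
    threads. A thread is a tuple of child indices; its level-k segment is
    identified by the prefix path[:k] (the root segment is the empty tuple).
    """
    return list(product(*(range(b) for b in splits)))

def shared_levels(p, q):
    """Number of levels (counting the root segment) that threads p and q share."""
    n = 1  # every thread shares the root segment
    for a, b in zip(p, q):
        if a != b:
            break
        n += 1
    return n

def thread_estimate(path, segment_avgs, weights):
    """Weighted average of segment averages along one thread (cf. Figure 3).

    segment_avgs maps a path prefix (tuple) to that segment's average iterate;
    weights[k] is the weight for level k. Both are hypothetical inputs.
    """
    return sum(w * segment_avgs[path[:k]] for k, w in enumerate(weights))

threads = higrad_threads([2, 3])  # Figure 2: split into two, then each into three
```

Two threads that share more leading segments reuse more of the same iterates, so their predictions are more strongly correlated. `shared_levels` records exactly this overlap, which is what the decorrelation step has to account for.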

Theorems & Definitions (22)

  • Proposition 1
  • Remark 2
  • Theorem 3: Confidence intervals
  • Theorem 4: Prediction intervals
  • Remark 5
  • Proposition 6
  • Theorem 7: Prediction intervals for vanilla SGD
  • Lemma 8
  • Lemma 9
  • Proposition 10
  • ...and 12 more