
Tracking the Median of Gradients with a Stochastic Proximal Point Method

Fabian Schaipp, Guillaume Garrigos, Umut Simsekli, Robert Gower

TL;DR

This work introduces a stochastic proximal point (SPP) framework to track the median gradient online, providing a robust alternative to SGD when gradient noise is heavy-tailed or corrupted. By deriving online median estimators via SPP, the authors connect clipping-based methods to median estimation and momentum to mean estimation, establishing a $1/\sqrt{T}$ convergence rate under challenging noise conditions. They prove theoretical guarantees for Sample Median Gradient Descent, which uses multiple samples per iteration, and validate the approach with synthetic and language-modeling experiments, showing that median-based online estimators can outperform mean-based SGD in heavy-tailed regimes. The results offer a principled bridge between robust statistics and stochastic optimization, with practical implications for distributed learning and privacy-constrained settings where outliers and heavy tails are common.
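The Sample Median Gradient Descent idea can be illustrated with a small sketch. This is our own minimal illustration, not the authors' implementation: we assume a coordinate-wise median of `n_samples` stochastic gradients per step (the paper's estimator may differ, e.g. a geometric median), and the helper names (`coordinate_median`, `sample_median_gd`, `noisy_grad`) as well as the toy quadratic with standard Cauchy gradient noise are hypothetical.

```python
import math
import random
import statistics


def coordinate_median(grads):
    """Coordinate-wise median of a list of equal-length gradient vectors."""
    return [statistics.median(coords) for coords in zip(*grads)]


def sample_median_gd(grad_sample, x0, lr, n_samples, n_steps):
    """Descend along the coordinate-wise median of n_samples stochastic
    gradients drawn at the current iterate (assumed form of the method)."""
    x = list(x0)
    for _ in range(n_steps):
        grads = [grad_sample(x) for _ in range(n_samples)]
        g_med = coordinate_median(grads)
        x = [xi - lr * gi for xi, gi in zip(x, g_med)]
    return x


# Hypothetical toy problem: f(x) = 0.5 * ||x - x_star||^2, with independent
# standard Cauchy noise (infinite variance) added to each gradient coordinate.
rng = random.Random(0)
x_star = [1.0, -2.0]


def noisy_grad(x):
    return [xi - si + math.tan(math.pi * (rng.random() - 0.5))
            for xi, si in zip(x, x_star)]


x_final = sample_median_gd(noisy_grad, x0=[0.0, 0.0], lr=0.05,
                           n_samples=5, n_steps=500)
print(x_final)  # close to x_star, up to noise
```

Because the Cauchy noise is symmetric and independent across coordinates, the coordinate-wise median of the noisy gradients is centered on the true gradient, whereas the plain average of Cauchy samples is itself Cauchy distributed and has no finite mean.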

Abstract

There are several applications of stochastic optimization where one can benefit from a robust estimate of the gradient. Examples include distributed learning with corrupted nodes, the presence of large outliers in the training data, learning under privacy constraints, and even heavy-tailed noise due to the dynamics of the algorithm itself. Here we study SGD with robust gradient estimators based on estimating the median. We first derive iterative methods based on the stochastic proximal point method for computing the median gradient and generalizations thereof. Then we propose an algorithm estimating the median gradient across iterations, and find that several well-known methods are particular cases of this framework. For instance, we observe that different forms of clipping allow one to compute online estimators of the median of gradients, in contrast to (heavy-ball) momentum, which corresponds to an online estimator of the mean. Finally, we provide a theoretical framework for an algorithm computing the median gradient across samples, and show that the resulting method can converge even under heavy-tailed, state-dependent noise.
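The mean/median contrast can be made concrete in one dimension. A stochastic proximal point step with step size $\tau$ on the squared loss $\tfrac12(m-g)^2$ gives an exponential moving average (a momentum-style running mean), while the same step on the absolute loss $|m-g|$ gives a clipped move toward $g$. The sketch below is our own scalar illustration; the function names and the data stream are hypothetical, not code from the paper.

```python
def clip_scalar(v, tau):
    """Project v onto the interval [-tau, tau] (1-D version of clip_{tau,2})."""
    return max(-tau, min(tau, v))


def spp_mean_step(m, g, tau):
    """Proximal step on 0.5 * (m - g)**2 with step size tau.

    argmin_u 0.5*(u - g)**2 + (u - m)**2 / (2*tau) = (m + tau*g) / (1 + tau),
    i.e. an exponential moving average with weight beta = tau / (1 + tau).
    """
    beta = tau / (1.0 + tau)
    return (1.0 - beta) * m + beta * g


def spp_median_step(m, g, tau):
    """Proximal step on |m - g| with step size tau: move toward g by at most tau."""
    return m + clip_scalar(g - m, tau)


# Hypothetical stream: every sample equals 1.0 except one huge outlier.
stream = [1.0] * 100
stream[50] = 1e6

tau = 0.1
m_mean, m_median = 0.0, 0.0
for g in stream:
    m_mean = spp_mean_step(m_mean, g, tau)
    m_median = spp_median_step(m_median, g, tau)

print(f"mean estimate:   {m_mean:.1f}")    # still far from 1.0
print(f"median estimate: {m_median:.6f}")  # back at 1.0
```

The single outlier pulls the moving average far from 1.0 for many subsequent steps, whereas the clipped estimator moves by at most $\tau$ per step and returns to 1.0 immediately.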

Paper Structure

This paper contains 48 sections, 22 theorems, 111 equations, 10 figures, and 1 table.

Key Result

corollary 0

For $\mathcal{D} = \|\cdot\|_2$, the online SPP update (eqn:online-spp-update) becomes a clipped update, expressed through the operator $\mathrm{clip}_{\tau,2}(\bm{v}) := \frac{\tau}{\max\{\tau, \|\bm{v}\|_2\}} \bm{v}$.
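A minimal sketch of this clipped update, assuming it takes the form $\bm{m}_{t+1} = \bm{m}_t + \mathrm{clip}_{\tau,2}(\bm{g}_t - \bm{m}_t)$. That form is what a proximal step of size $\tau$ on $\|\bm{m} - \bm{g}_t\|_2$ yields, since $\operatorname{argmin}_{\bm{u}} \|\bm{u}-\bm{g}\|_2 + \tfrac{1}{2\tau}\|\bm{u}-\bm{m}\|_2^2 = \bm{m} - \mathrm{clip}_{\tau,2}(\bm{m}-\bm{g})$. The symbols $\bm{m}_t$, $\bm{g}_t$, the function names, and the synthetic 2-D stream are our own assumptions rather than the paper's notation.

```python
import math
import random


def clip_l2(v, tau):
    """clip_{tau,2}(v) = tau / max(tau, ||v||_2) * v."""
    norm = math.sqrt(sum(vi * vi for vi in v))
    scale = tau / max(tau, norm)
    return [scale * vi for vi in v]


def online_median_step(m, g, tau):
    """Assumed online median update: m + clip_{tau,2}(g - m), i.e. the
    proximal step argmin_u ||u - g||_2 + ||u - m||_2^2 / (2*tau)."""
    step = clip_l2([gi - mi for gi, mi in zip(g, m)], tau)
    return [mi + si for mi, si in zip(m, step)]


rng = random.Random(1)
center = [3.0, -1.0]


def sample_gradient():
    # Hypothetical stream: center plus independent standard Cauchy noise
    # in each coordinate (heavy-tailed, infinite variance).
    return [c + math.tan(math.pi * (rng.random() - 0.5)) for c in center]


m = [0.0, 0.0]
for _ in range(2000):
    m = online_median_step(m, sample_gradient(), tau=0.05)
print(m)  # near center
```

Each step moves the estimate by at most $\tau$, so a single extreme sample can shift it by no more than $\tau$; the estimate settles near the center, which is the geometric median of this point-symmetric distribution.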

Figures (10)

  • Figure 1: First and second moments of the sample median and the sample mean for a standard $\alpha$-stable distribution. We approximate $\mathbb{E}$ by running $10\,000$ trials. The sample median has much smaller variance than the sample mean (cf. the sample-median noise condition).
  • Figure 2: ($\tau=0.01$) Left: Final error for varying values of $\alpha$ (from left to right, the distributions become more heavy-tailed). The shaded area marks the minimal and maximal values over $50$ independent runs. Right: Convergence plot for all methods for $\tfrac{1}{\alpha}\in[0.5,1.5]$ (a higher value of $\tfrac{1}{\alpha}$ corresponds to heavier tails).
  • Figure 3: Least squares with heavy-tailed noise that is independent over components (left), state-dependent (middle), and dependent across components (right). All methods in gray use $n=5$ samples per iteration; all others use only one. Note that in setting (item:independent-state), the sample mean diverges immediately and hence does not appear in the plot. The shaded area depicts the minimum and maximum values over 50 repetitions.
  • Figure 4: Training loss for each method, with tuned learning rate. Shaded area depicts minimum and maximum value over 5 seeds.
  • Figure 5: Validation score (measured as perplexity, or F1 score) for each method, with tuned learning rate. Shaded area depicts minimum and maximum value over 5 seeds.
  • ...and 5 more figures
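The comparison in Figure 1 can be reproduced in spirit with a short simulation. We assume $\alpha = 1$, because the standard $\alpha$-stable distribution with $\alpha = 1$ is the standard Cauchy distribution, which is easy to sample by inverting its CDF; the sample size `n = 5` and the seed are arbitrary choices, not the paper's settings, while the $10\,000$ trials match the caption.

```python
import math
import random
import statistics

rng = random.Random(2)


def cauchy():
    # Standard Cauchy sample via inverse CDF (alpha-stable with alpha = 1).
    return math.tan(math.pi * (rng.random() - 0.5))


n, trials = 5, 10_000
means, medians = [], []
for _ in range(trials):
    xs = [cauchy() for _ in range(n)]
    means.append(statistics.fmean(xs))
    medians.append(statistics.median(xs))

# Empirical second moments: finite for the median of 5 Cauchy samples,
# infinite in theory for the sample mean (so the estimate is huge and unstable).
second_mean = statistics.fmean(x * x for x in means)
second_median = statistics.fmean(x * x for x in medians)
print(f"E[mean^2] ~ {second_mean:.3g}, E[median^2] ~ {second_median:.3g}")
```

Rerunning with other seeds makes the point visible: the median's estimate stays stable, while the mean's estimate jumps by orders of magnitude, driven by a few extreme draws.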

Theorems & Definitions (38)

  • corollary 0
  • corollary 0
  • corollary 0
  • proposition 2
  • proposition 2
  • corollary 2
  • lemma 3
  • proof
  • lemma 4
  • proof
  • ...and 28 more