Per-example gradients: a new frontier for understanding and improving optimizers

Vincent Roulet, Atish Agarwala

TL;DR

This work makes per-example gradient statistics tractable at scale. It uses computational-graph surgery and JAX/vmap prototyping to access quantities such as $\mathbb{E}[\nabla f(\theta; X)]$, $\mathbb{E}[\nabla f(\theta; X)^2]$, and $\mathbb{E}[\operatorname{sign}(\nabla f(\theta; X))]$, where the expectation is over examples $X$. It shows that averages of non-linear functions of per-example gradients can be computed with minimal overhead in many architectures, and that these statistics shed light on optimizer behavior. Two findings stand out. First, SignSGD is best served by applying the sign as late as possible in the processing chain (SignEMA), which improves stability. Second, Adam's preconditioner benefits from being dominated by $\mu^2$, the squared mean of the gradient distribution, rather than by its variance $\sigma^2$; this motivates new Adam variants (MicroAdamVar, MicroAdamMSQ). Overall, the work expands the optimizer design space by enabling the analysis and development of per-example gradient transformations. It includes practical demonstrations on transformer-scale models and discusses implications for training stability and scaling.
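The TL;DR mentions two routes to these statistics: JAX/vmap prototyping and graph surgery. The sketch below compares them on the simplest case, assuming a single dense layer `W`, a squared-error loss, and one input vector per example. Here, `jax.vmap` over `jax.grad` materializes every per-example gradient. The "surgery" route instead applies the function before the batch sum. For this layer, each per-example gradient is an outer product $r_i x_i^\top$, so the square and the sign factor through it. For sequence models, a per-example gradient is a sum of outer products over tokens, and this factorization no longer holds. The sketch says nothing about how the paper handles that case. All function names are illustrative.

```python
import jax
import jax.numpy as jnp


def per_example_loss(W, x, t):
    # Squared error of one dense layer on a single example (x: (d_in,), t: (d_out,)).
    r = W @ x - t
    return 0.5 * jnp.sum(r ** 2)


def stats_via_vmap(W, X, T):
    # Prototyping route: materialize all B per-example gradients, shape (B, d_out, d_in),
    # then average identity, square and sign over the batch axis.
    G = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))(W, X, T)
    return G.mean(axis=0), (G ** 2).mean(axis=0), jnp.sign(G).mean(axis=0)


def stats_via_surgery(W, X, T):
    # Surgery route (sketch): the per-example gradient is r_i x_i^T, and
    # phi(r_i x_i^T) = phi(r_i) phi(x_i)^T for phi in {identity, square, sign}.
    # Applying phi before the batch sum turns each statistic into one matmul,
    # so no (B, d_out, d_in) tensor is ever formed.
    B = X.shape[0]
    R = X @ W.T - T  # per-example residuals, shape (B, d_out)
    mean = R.T @ X / B
    mean_sq = (R ** 2).T @ (X ** 2) / B
    mean_sign = jnp.sign(R).T @ jnp.sign(X) / B
    return mean, mean_sq, mean_sign


kW, kX, kT = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(kW, (3, 5))
X = jax.random.normal(kX, (8, 5))
T = jax.random.normal(kT, (8, 3))
for a, b in zip(stats_via_vmap(W, X, T), stats_via_surgery(W, X, T)):
    print(jnp.allclose(a, b, atol=1e-4))  # each comparison should print True
```

The vmap route is easy to write for any model but holds `B` copies of the weight-shaped gradient. The factorized route has the same memory profile as the ordinary mini-batch gradient, but it relies on the layer's outer-product structure.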

Abstract

Training algorithms in deep learning usually treat a mini-batch of samples as a single object: they average gradients over the mini-batch and then process the average in various ways. Computing statistics other than the average may have been seen as prohibitively resource-intensive in automatic differentiation (AD) frameworks. We show that this is not the case. In general, gradient statistics can be implemented through a surgery of the AD graph, which in some cases incurs almost no computational or memory overhead compared to the mini-batch gradient computation. We also show that for certain classes of models, including transformers, JAX's vectorization transformation offers a viable implementation for prototyping and experimentation. We then revise our understanding of two nonlinear operations in optimization through the lens of per-example gradient transformations. We first study SignSGD and show that the placement of the sign operation in the gradient processing chain is crucial to success, and that the optimal placement can be predicted with a simple signal-to-noise ratio argument. Next, we study per-example variations of the Adam preconditioner. In contrast to conventional wisdom, we show that optimization is best served when the preconditioner is dominated by the mean rather than the variance of the gradient distribution. Overall, we demonstrate that per-example gradient information enables new analyses and new possibilities for algorithm design.
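A toy illustration of why the placement of the sign matters. The sign of an average and the average of signs are different statistics, and with skewed per-example gradients they can even disagree in sign. The function names mirror the SignSGD variants in Figure 3, but the definitions below are our assumptions, including an EMA with $\beta_1 = 0.9$ applied before the sign. The sketch does not reproduce the paper's signal-to-noise argument.

```python
import jax.numpy as jnp

# Hypothetical update directions named after the SignSGD variants in Figure 3.
# The exact definitions are our reading of the names, not the paper's code.


def micro_sign_sgd(per_example_grads):
    # Earliest placement: sign of each per-example gradient, then batch average.
    return jnp.sign(per_example_grads).mean(axis=0)


def sign_sgd(per_example_grads):
    # Middle placement: batch average first, then sign.
    return jnp.sign(per_example_grads.mean(axis=0))


def sign_ema(per_example_grads, ema, beta1=0.9):
    # Latest placement: batch average, then EMA across steps, then sign.
    ema = beta1 * ema + (1.0 - beta1) * per_example_grads.mean(axis=0)
    return jnp.sign(ema), ema


# Three per-example gradients for a single scalar parameter, shape (B, 1).
G = jnp.array([[3.0], [-1.0], [-1.0]])
print(micro_sign_sgd(G))  # mean of signs is -1/3: opposite to the mini-batch mean
print(sign_sgd(G))        # the mean is +1/3, so its sign is +1
update, ema = sign_ema(G, jnp.zeros(1))  # +1 here; later steps reuse the returned ema
```

Applying the sign early discards per-example magnitudes before any averaging. Applying it late keeps as much averaging as possible in front of the nonlinearity.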

Paper Structure

This paper contains 42 sections, 46 equations, 10 figures, and 1 algorithm.

Figures (10)

  • Figure 1: Computational graph of the mini-batch loss and of the mini-batch gradient with respect to some intermediate weights. The forward pass has an independent computational path for each datapoint $x_{i}$ and intermediate activation $s_{i}$. The weights are effectively broadcast to each path before the final loss merges them (left). In the backward pass, the residuals $r_{i}$ travel along the reversed computational paths and are similarly broadcast. The paths merge only at the end, via a sum reduction, which is the adjoint of the weight broadcasting. Statistics of a function $\phi$ of the gradients can therefore be computed by injecting $\phi$ just before the final sum reduction.
  • Figure 2: Memory footprint over program execution, as reported by a code profiler, for the train step of the usual Adam algorithm and of its per-example variant, MicroAdam (see the MicroAdam section), on a 1.2B-parameter transformer in Nanodo [nanodo]. The peak memory corresponds to the memory accumulated during the forward pass of automatic differentiation, which is needed to compute gradients. The per-example variant may incur more operations (longer tail), which translates into longer execution time, but the peak memory is the same.
  • Figure 3: Learning curves for SignSGD variants at their optimal learning rates, with $\beta_{1} = 0.9$. SignEMA performs best and MicroSignSGD performs worst. This suggests that the sign function should be applied as late as possible, to avoid reducing the signal-to-noise ratio of the gradients of individual parameters.
  • Figure 4: MicroAdam (orange) emphasizes variance information in its preconditioner and generally trains less stably and more slowly than Adam (blue). MicroAdamMSQ (green) emphasizes mean-squared information and shows slight gains.
  • Figure 5: Adam-family variants trained at various batch sizes, each with its own learning-rate scaling rule. Adam (top left) is trained with $\eta\propto\sqrt{B}$ and shows universal loss curves at intermediate batch sizes, but not at small or large batch sizes. MicroAdam (top right), MicroAdamMSQ (bottom left), and MicroAdamVar (bottom right) all show universal scaling at small and intermediate batch sizes with $\eta\propto B$. Adam-family members with a larger $\sigma^{2}$ contribution to the preconditioner suffer from stability issues. This is especially true of MicroAdamVar, which depends only on $\sigma^{2}$. Learning rates are chosen to be close to optimal at $B = 64$.
  • ...and 5 more figures
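Figures 4 and 5 contrast preconditioners dominated by the squared mean $\mu^2$ with ones dominated by the variance $\sigma^2$. A standard identity makes the contrast concrete. Take $B$ per-example gradients of one coordinate, with batch mean $\bar g$ and biased batch variance $s^2$. Then $\frac{1}{B}\sum_i g_i^2 = \bar g^2 + s^2$. If the $g_i$ are i.i.d. with mean $\mu$ and variance $\sigma^2$, then $\mathbb{E}[\bar g^2] = \mu^2 + \sigma^2/B$, while $\mathbb{E}[\frac{1}{B}\sum_i g_i^2] = \mu^2 + \sigma^2$. Standard Adam accumulates $\bar g^2$ in its second-moment estimate. The sketch assumes that a per-example variant such as MicroAdam accumulates the mean of squares instead. The helper name is hypothetical, and the MicroAdamMSQ and MicroAdamVar estimators are not reconstructed here.

```python
import jax
import jax.numpy as jnp


def second_moment_candidates(per_example_grads):
    # Hypothetical helper. per_example_grads has shape (B, ...): one gradient per example.
    g_bar = per_example_grads.mean(axis=0)
    square_of_mean = g_bar ** 2                              # what standard Adam accumulates
    mean_of_squares = (per_example_grads ** 2).mean(axis=0)  # per-example second moment
    variance = per_example_grads.var(axis=0)                 # biased batch variance s^2
    return square_of_mean, mean_of_squares, variance


# Monte Carlo check for one coordinate with i.i.d. Gaussian per-example
# gradients (an assumption made only for this illustration).
mu, sigma, B, n_batches = 0.1, 1.0, 16, 20_000
G = mu + sigma * jax.random.normal(jax.random.PRNGKey(0), (n_batches, B))
som, mos, var = jax.vmap(second_moment_candidates)(G)

print(jnp.allclose(mos, som + var, atol=1e-5))  # identity holds per batch, up to rounding
print(som.mean(), mu**2 + sigma**2 / B)  # variance enters Adam's input scaled by 1/B
print(mos.mean(), mu**2 + sigma**2)      # but enters the mean of squares at full weight
```

Under this reading, averaging before squaring shrinks the $\sigma^2$ term by $1/B$. This is one way to see why a per-example second moment places more weight on the variance than Adam's does.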

Theorems & Definitions (1)

  • proof