On the Relation Between the Sharpest Directions of DNN Loss and the SGD Step Length

Stanisław Jastrzębski, Zachary Kenton, Nicolas Ballas, Asja Fischer, Yoshua Bengio, Amos Storkey

TL;DR

This work investigates how SGD interacts with the sharpest directions of the loss surface (the top eigenvectors of the Hessian) throughout neural network training. It shows that the largest curvature grows early in training and peaks at a level determined by the learning rate and batch size, and that the SGD step is often too large to minimize the loss within the subspace of the sharpest directions. By analyzing gradients projected onto these directions, the authors show that optimization along them is ineffective even though the gradient remains strongly aligned with them, which motivates a targeted variant. They introduce Nudged-SGD, which reduces the learning rate along the sharpest directions; it can train faster and reach sharper yet better-generalizing minima, with results that depend on model and dataset. Overall, the paper highlights how curvature in the sharpest directions shapes training dynamics and generalization, offering a framework for designing optimizers tailored to the loss surface of deep networks.

Abstract

Stochastic Gradient Descent (SGD) based training of neural networks with a large learning rate or a small batch-size typically ends in well-generalizing, flat regions of the weight space, as indicated by small eigenvalues of the Hessian of the training loss. However, the curvature along the SGD trajectory is poorly understood. An empirical investigation shows that initially SGD visits increasingly sharp regions, reaching a maximum sharpness determined by both the learning rate and the batch-size of SGD. When studying the SGD dynamics in relation to the sharpest directions in this initial phase, we find that the SGD step is large compared to the curvature and commonly fails to minimize the loss along the sharpest directions. Furthermore, using a reduced learning rate along these directions can improve training speed while leading to both sharper and better generalizing solutions compared to vanilla SGD. In summary, our analysis of the dynamics of SGD in the subspace of the sharpest directions shows that they influence the regions that SGD steers to (where larger learning rate or smaller batch size result in wider regions visited), the overall training speed, and the generalization ability of the final model.
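
The relation between step length and curvature described in the abstract can be illustrated on a toy problem. The sketch below is not the authors' implementation: it assumes an axis-aligned quadratic loss (so the Hessian eigenvectors are simply the coordinate axes), uses the full gradient rather than a mini-batch gradient, and the names `step_regime`, `nudged_step`, `gamma`, and `k`, as well as all numeric values, are hypothetical choices made for illustration.

```python
# Toy setting (an assumption for illustration): an axis-aligned quadratic
# L(x) = 0.5 * sum_i lam[i] * x[i]**2. Its Hessian eigenvectors are the
# coordinate axes and its eigenvalues (curvatures) are lam[i].

def loss(x, lam):
    """Quadratic loss with per-axis curvature lam."""
    return 0.5 * sum(l * xi * xi for l, xi in zip(lam, x))

def step_regime(eta, lam_i):
    """Classify a gradient step along one direction with curvature lam_i."""
    r = eta * lam_i  # each step multiplies the coordinate by (1 - r)
    if r <= 1.0:
        return "no crossing"
    if r < 2.0:
        return "crosses minimum, loss decreases"
    return "crosses minimum, loss does not decrease"

def nudged_step(x, lam, eta, gamma, k):
    """One gradient step using learning rate gamma * eta along the k
    sharpest axes and eta along all remaining axes."""
    sharpest = sorted(range(len(lam)), key=lambda i: lam[i], reverse=True)[:k]
    new_x = []
    for i, (l, xi) in enumerate(zip(lam, x)):
        rate = gamma * eta if i in sharpest else eta
        new_x.append(xi - rate * l * xi)  # gradient along axis i is l * xi
    return new_x

lam = [160.0, 1.0]  # one sharp and one flat direction (illustrative values)
x0 = [1.0, 1.0]
eta = 0.01          # eta * lam[0] = 1.6, so the sharp direction overshoots

plain = nudged_step(x0, lam, eta, gamma=1.0, k=1)   # gamma = 1 is a plain step
nudged = nudged_step(x0, lam, eta, gamma=0.5, k=1)  # halve the rate on the sharp axis

print(step_regime(eta, lam[0]), plain, nudged)
print(loss(x0, lam), loss(plain, lam), loss(nudged, lam))
```

With these toy values the plain step flips the sign of the sharp coordinate (it crosses the minimum along that axis), while the step with a reduced rate along the sharp axis stays on the same side and ends at a lower loss. This only mirrors the qualitative mechanism; it says nothing about the generalization effects reported for real networks.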

Paper Structure

This paper contains 28 sections, 1 equation, 22 figures, 7 tables.

Figures (22)

  • Figure 1: Left: Outline of the phenomena discussed in the paper. Curvature along the sharpest direction(s) initially grows (A to C). In most iterations, we find that SGD crosses the minimum if restricted to the subspace of the sharpest direction(s) by taking too large a step (B and C). Finally, curvature stabilizes or decays with a peak value determined by learning rate and batch size (C, see also right). Right two: Representative example of the evolution of the top $30$ (decreasing, red to blue) eigenvalues of the Hessian for a SimpleCNN model during training (with $\eta=0.005$, note that $\eta$ is close to $\frac{1}{\lambda_{max}} = \frac{1}{160}$).
  • Figure 2: Top: Evolution of the top $10$ eigenvalues of the Hessian for SimpleCNN and Resnet-32 trained on the CIFAR-10 dataset with $\eta=0.1$ and $S=128$. Bottom: Zoom on the evolution of the top $10$ eigenvalues in the beginning of training. A sharp initial growth of the largest eigenvalues followed by an oscillatory-like evolution is visible. Training and test accuracy of the corresponding models are provided for reference.
  • Figure 3: Full batch-size training of Resnet-32 for $\eta=0.01$ (left) and $\eta=0.05$ (right) on CIFAR-10. Evolution of the top $10$ eigenvalues of the Hessian and accuracy are shown in each case. The training is unstable; an initial growth of curvature scale is followed by a sudden drop in accuracy. The CIFAR-10 dataset was subsampled to $2500$ points.
  • Figure 4: Evolution of the two largest eigenvalues (solid and dashed line) of the Hessian for Resnet-32, SimpleCNN, and LSTM (trained on the PTB dataset) models on a log-scale for different learning rates (top) and batch-sizes (bottom). Blue/green/red correspond to increasing $\eta$ and decreasing $S$ in each figure. Left side of the vertical blue bar in each plot corresponds to the early phase of training. Larger learning rate or a smaller batch-size correlates with a smaller and earlier peak of the spectral norm and the next largest eigenvalue.
  • Figure 5: Average cosine between mini-batch gradient (y axis) and top sharpest directions (averaged over top $5$) for different $\eta$ (color), evaluated at different levels of accuracy during training (x axis). For comparison, the horizontal purple line is the alignment with a random vector in the parameter space. Curves were smoothed for clarity.
  • ...and 17 more figures
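
The two quantities tracked in the captions above, the largest Hessian eigenvalue and the cosine between the gradient and a sharpest direction, can be estimated without forming the Hessian. The following is a minimal sketch under stated assumptions, not the procedure used in the paper: it runs power iteration on finite-difference Hessian-vector products for a small quadratic toy loss. The matrix `A`, the point `theta`, and the helper names `hvp`, `top_eigenpair`, and `cosine` are hypothetical.

```python
import math

# Toy loss L(theta) = 0.5 * theta^T A theta, so the Hessian is A everywhere.
A = [[4.0, 1.0], [1.0, 3.0]]  # illustrative symmetric matrix

def grad(theta):
    """Gradient of the toy loss, A @ theta."""
    n = len(theta)
    return [sum(A[i][j] * theta[j] for j in range(n)) for i in range(n)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def norm(u):
    return math.sqrt(dot(u, u))

def hvp(grad_fn, theta, v, eps=1e-3):
    """Hessian-vector product via a central difference of gradients."""
    plus = grad_fn([t + eps * vi for t, vi in zip(theta, v)])
    minus = grad_fn([t - eps * vi for t, vi in zip(theta, v)])
    return [(p - m) / (2.0 * eps) for p, m in zip(plus, minus)]

def top_eigenpair(grad_fn, theta, iters=200):
    """Power iteration: returns the eigenvalue of largest magnitude and a
    unit eigenvector. For a Hessian with a dominant negative eigenvalue
    this would return that one instead of the sharpest direction."""
    v = [1.0] * len(theta)
    n = norm(v)
    v = [vi / n for vi in v]
    lam = 0.0
    for _ in range(iters):
        hv = hvp(grad_fn, theta, v)
        lam = dot(v, hv)  # Rayleigh quotient for the current unit vector
        n = norm(hv)
        v = [h / n for h in hv]
    return lam, v

def cosine(u, v):
    return dot(u, v) / (norm(u) * norm(v))

theta = [1.0, -2.0]
lam_max, e1 = top_eigenpair(grad, theta)
alignment = abs(cosine(grad(theta), e1))  # sign of an eigenvector is arbitrary
print(lam_max, e1, alignment)
```

For a real network, `grad` would be replaced by a (mini-batch) gradient from an automatic differentiation library, and exact Hessian-vector products are usually preferable to finite differences; those substitutions are left out here to keep the sketch self-contained.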