Training trajectories, mini-batch losses and the curious role of the learning rate

Mark Sandler, Andrey Zhmoginov, Max Vladymyrov, Nolan Miller

TL;DR

The work investigates how stochastic gradient descent traverses the loss landscape by examining the loss on fixed training batches along the trajectory, revealing convex-like, often quadratic behavior that enables rapid descent with large learning rates. It introduces a simple quadratic per-batch loss model and shows that various weight-averaging methods (EMA, SWA, two-point averaging) correspond to specific effective learning-rate schedules, a connection supported by large-scale ImageNet experiments with ResNet. The analysis links averaging to a reduced learning rate and shows how averaging moves the iterate inside the ellipsoid around the minimum on which the SGD trajectory stabilizes, with two-point averaging and EMA providing notable accuracy improvements, especially in multi-timescale dynamics. The study validates these insights on ImageNet and CIFAR, analyzes gradient alignment, and discusses basin geometry and open questions for extending the framework to full trajectories and trajectory-dependent learning-rate strategies.

Abstract

Stochastic gradient descent (SGD) plays a fundamental role in nearly all applications of deep learning. However, its ability to converge to a global minimum remains shrouded in mystery. In this paper we propose to study the behavior of the loss function on fixed mini-batches along SGD trajectories. We show that the loss function on a fixed batch appears to be remarkably convex-like. In particular, for ResNet the loss on any fixed mini-batch can be accurately modeled by a quadratic function, and a very low loss value can be reached in just one step of gradient descent with a sufficiently large learning rate. We propose a simple model that allows us to analyze the relationship between the gradients of stochastic mini-batches and that of the full batch. Our analysis reveals an equivalence between iterate aggregates and specific learning rate schedules. In particular, for Exponential Moving Average (EMA) and Stochastic Weight Averaging (SWA) we show that our proposed model matches the observed training trajectories on ImageNet. Our theoretical model predicts that an even simpler averaging technique, averaging just two points many steps apart, significantly improves accuracy compared to the baseline. We validated our findings on ImageNet and other datasets using the ResNet architecture.
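
If a batch loss is exactly quadratic, the claim that one large gradient step can drive it very low is easy to check directly. The sketch below is illustrative only and assumes a hypothetical diagonal quadratic batch loss $L(w) = \tfrac{1}{2}\sum_i h_i (w_i - c_i)^2$ with made-up curvatures $h_i$ and offsets $d_i = w_i - c_i$; the helpers `one_step_loss` and `best_step_size` are our own names, not the authors' code. One step of size $\lambda$ multiplies each offset by $(1 - \lambda h_i)$, so the post-step loss is itself a parabola in $\lambda$ with a closed-form minimizer.

```python
def one_step_loss(lr, h, d):
    """Batch loss after one gradient step of size lr.

    Hypothetical diagonal quadratic batch loss L(w) = 0.5 * sum_i h_i * (w_i - c_i)^2;
    h holds the curvatures and d the initial offsets w0_i - c_i.
    """
    # A gradient step multiplies each offset by (1 - lr * h_i).
    return 0.5 * sum(hi * (1.0 - lr * hi) ** 2 * di ** 2 for hi, di in zip(h, d))


def best_step_size(h, d):
    """Closed-form minimizer of one_step_loss over lr (the loss is a parabola in lr)."""
    # d/d(lr) of one_step_loss is -sum h_i^2 (1 - lr h_i) d_i^2; set it to zero.
    num = sum(hi ** 2 * di ** 2 for hi, di in zip(h, d))
    den = sum(hi ** 3 * di ** 2 for hi, di in zip(h, d))
    return num / den


# Illustrative (made-up) curvatures and offsets for one fixed batch.
h = [1.0, 1.2, 0.8, 1.1]
d = [1.0, -0.5, 2.0, 0.3]
lr_star = best_step_size(h, d)
print(f"initial loss {one_step_loss(0.0, h, d):.4f}, "
      f"after one step with lr={lr_star:.3f}: {one_step_loss(lr_star, h, d):.4f}")
```

With these values the best single step leaves well under 5% of the initial loss. If all curvatures were equal to $h$, the minimizer would be $\lambda^* = 1/h$ and the loss would drop to exactly zero; it is the spread of curvatures that keeps one step from being perfect.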

Paper Structure

This paper contains 29 sections, 2 theorems, 38 equations, 14 figures, and 1 table.

Key Result

Lemma 1

If $E_{x \sim \mathcal{D}}\left[A_x^T A_x\right] = I$, the trajectory will stabilize on an ellipsoid whose size is fixed and proportional to the square root of the learning rate, $\sqrt{\lambda}$.
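
A quick way to see the $\sqrt{\lambda}$ scaling is a toy simulation of the simplest case that satisfies the lemma's condition, $A_x = I$, where the ellipsoid becomes a sphere. This is a minimal sketch under our own assumptions, not the paper's model or code: each mini-batch $x$ has the hypothetical loss $\tfrac{1}{2}\|w - c_x\|^2$ with batch minima $c_x \sim \mathcal{N}(0, \sigma^2 I)$ scattered around the full-batch minimum at the origin, and `stationary_radius` and `predicted_radius` are illustrative helper names. The SGD update $w \leftarrow (1-\lambda) w + \lambda c_x$ has stationary per-coordinate variance $\lambda\sigma^2/(2-\lambda)$, so the radius grows like $\sqrt{\lambda}$ for small $\lambda$.

```python
import math
import random


def stationary_radius(lr, dim=10, sigma=1.0, burn_in=2000, steps=20000, seed=0):
    """Empirical RMS distance of SGD iterates from the full-batch minimum (the origin).

    Hypothetical toy model: each mini-batch x has loss 0.5 * ||w - c_x||^2, i.e. A_x = I,
    with batch minima c_x ~ N(0, sigma^2 I) scattered around the full-batch minimum.
    """
    rng = random.Random(seed)
    w = [0.0] * dim
    total_sq = 0.0
    for t in range(burn_in + steps):
        c = [rng.gauss(0.0, sigma) for _ in range(dim)]  # minimum of this batch
        w = [wi - lr * (wi - ci) for wi, ci in zip(w, c)]  # SGD step: grad = w - c_x
        if t >= burn_in:  # only average once the iterates have become stationary
            total_sq += sum(wi * wi for wi in w)
    return math.sqrt(total_sq / steps)


def predicted_radius(lr, dim=10, sigma=1.0):
    """Closed form for the same toy model: per-coordinate variance lr * sigma^2 / (2 - lr)."""
    return math.sqrt(dim * lr * sigma ** 2 / (2.0 - lr))


for lr in (0.01, 0.04):
    print(f"lr={lr}: simulated {stationary_radius(lr):.3f}, predicted {predicted_radius(lr):.3f}")
```

Quadrupling the learning rate from 0.01 to 0.04 should roughly double the stationary radius, in line with the lemma's $\sqrt{\lambda}$ dependence.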

Figures (14)

  • Figure 1: Weight averaging vs. the equivalent learning rate schedule. The dotted vertical red lines show a set of independent trajectories ("side-trips") with the appropriate learning rate schedule, each starting at the corresponding point of the main trajectory.
  • Figure 2: Loss as a function of the step size for a single step of gradient descent in the middle of a trajectory. The step shown here is step $75\,000$ of a $100\,000$-step run; the behavior is typical of other steps as well. The small rectangle in the left graph approximately marks the region shown zoomed in on the right graph. The loss on unseen batches is computed over $10$ batches; the dark shade shows the standard deviation around the mean, and the light shade shows the observed min/max values.
  • Figure 3: Loss for a fixed batch along the interpolation between two points of the original training trajectory. We use two points of the ImageNet training trajectory, at $t=75\,000$ and $t=85\,000$. From each point we perform 9 steps of gradient descent on a fixed batch and measure the loss along the interpolation between corresponding pairs of points of the two trajectories. The left graph shows the loss on the training batch; note how within just 3 steps it reaches nearly 0. The middle graph shows the loss on a held-out batch. At step 0 the training and held-out batch loss profiles are very similar. The rightmost graph visualizes the process.
  • Figure 4: Stochastic gradient descent. The 2D panel shows how the stochastic gradient direction decomposes into a drift and a drag component. The magnitude of the drag component scales as $\sqrt{\lambda}$, while the magnitude of the drift component is proportional to $\lambda$. Further, the drift component has a quadratic bias toward increasing the loss, which makes it possible to pick a learning rate that reduces the loss. The 3D panel illustrates how a stationary trajectory traverses a sphere of fixed loss, while taking a midpoint gets inside the sphere. Best viewed in color.
  • Figure 5: Comparing learning rate schedules with aggregation. The dotted vertical red lines show a set of independent trajectories ("side-trips") with the appropriate learning rate schedule, each starting at the corresponding point of the main trajectory. For midpoint and average, the red circles at the midpoint show the accuracy at 1, 10, 100 and 1000 steps. For EMA the side-trips are 3000 steps long. The alternating dash/dot lines show the running averages of the main trajectory (solid line, bottom).
  • ...and 9 more figures
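
Figures 4 and 5 suggest why averaging helps: a stationary trajectory keeps circling a surface of roughly constant loss, while an average of its points lies inside it. The sketch below reproduces that effect in a hypothetical toy model of our own (batch loss $\tfrac{1}{2}\|w - c_x\|^2$ with $c_x \sim \mathcal{N}(0, \sigma^2 I)$, learning rate $\lambda$) and is not the paper's experimental setup; `averaging_radii`, `gap` and `beta` are illustrative names. Two iterates many multiples of $1/\lambda$ steps apart are nearly uncorrelated, so their midpoint has about half the squared distance to the minimum of a single iterate; by the toy model's stationary variance $\lambda\sigma^2/(2-\lambda)$, that is roughly what plain SGD would reach at half the learning rate. An EMA with a long horizon averages many such points and moves further inside.

```python
import math
import random
from collections import deque


def averaging_radii(lr=0.01, gap=1000, beta=0.999, dim=10, sigma=1.0,
                    burn_in=5000, steps=40000, seed=0):
    """Stationary RMS distance to the full-batch minimum (the origin) for three estimates.

    raw:      the current SGD iterate w_t
    midpoint: the two-point average (w_t + w_{t-gap}) / 2
    ema:      an exponential moving average of the iterates with decay beta

    Hypothetical toy model: batch loss 0.5 * ||w - c_x||^2 with c_x ~ N(0, sigma^2 I).
    """
    assert burn_in > gap, "need gap + 1 stored iterates before measuring"
    rng = random.Random(seed)
    w = [0.0] * dim
    ema = [0.0] * dim
    history = deque(maxlen=gap + 1)  # holds w_{t-gap}, ..., w_t
    sq = {"raw": 0.0, "midpoint": 0.0, "ema": 0.0}
    for t in range(burn_in + steps):
        c = [rng.gauss(0.0, sigma) for _ in range(dim)]
        w = [wi - lr * (wi - ci) for wi, ci in zip(w, c)]  # plain SGD step
        ema = [beta * ei + (1.0 - beta) * wi for ei, wi in zip(ema, w)]
        history.append(w)
        if t >= burn_in:
            # The deque is full here, so history[0] is w_{t-gap}.
            mid = [(a + b) / 2.0 for a, b in zip(w, history[0])]
            sq["raw"] += sum(x * x for x in w)
            sq["midpoint"] += sum(x * x for x in mid)
            sq["ema"] += sum(x * x for x in ema)
    return {name: math.sqrt(total / steps) for name, total in sq.items()}


radii = averaging_radii()
for name in ("raw", "midpoint", "ema"):
    print(f"{name:>8}: {radii[name]:.3f}  (ratio to raw {radii[name] / radii['raw']:.2f})")
```

In this toy setting the midpoint-to-raw ratio should come out close to $1/\sqrt{2} \approx 0.71$, and the EMA ratio lower still; the exact correspondence between EMA and a learning rate schedule derived in the paper is not reproduced here.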

Theorems & Definitions (2)

  • Lemma 1
  • Lemma 2