Batch size invariant Adam

Xi Wang, Laurence Aitchison

TL;DR

The paper tackles batch size invariance in Adam for large-scale distributed training. It introduces Batch size invariant Adam, which computes the second-moment estimate $\hat{V}$ using the average of squared micro-batch gradients rather than the square of the average gradient, yielding invariance to mini-batch size under mild update-size constraints. The authors provide theoretical results showing equivalence between multi-step micro-batch updates and a single invariant update, along with empirical validation on ResNet-18 and ViT across CIFAR-10 tasks, where the invariant variant remains stable across varying batch sizes while standard Adam with $\sqrt{B}$ scaling exhibits discrepancies. The work offers practical benefits for hyperparameter transfer and potential memory savings in server-side gradient accumulation, contributing a robust alternative to existing batch-size correction methods in adaptive optimizers.

Abstract

We propose a batch size invariant version of Adam, for use in large-scale, distributed settings, in which the mini-batch is divided into micro-batches which are distributed among worker nodes. For the v term, standard Adam first computes the average over micro-batch gradients, then squares, while in the batch size invariant Adam proposed here, we first square the micro-batch gradients, then average. Previous work (e.g. Malladi et al. 2022) used an alternative approach that involved a square-root scaling of the learning rate, but this approach requires strong assumptions to work; in particular that the gradient variance dominates the square of the expected gradient. In contrast, the approach proposed here gives batch size invariance without this assumption. We confirm that in practice our scheme gives batch size invariance in a much larger range of scenarios than the previous approach.
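The difference between the two orderings described above can be illustrated with a small NumPy sketch. This is not the authors' implementation: the function names, the Gaussian noise model, and the values of `mu` and `sigma` are assumptions made purely for illustration, and the exponential moving average that Adam applies on top of these quantities is omitted.

```python
import numpy as np

def second_moment_inputs(micro_grads):
    """Return the two candidate inputs to Adam's v term.

    micro_grads: array of shape (kappa, n), one gradient row per micro-batch.
    """
    micro_grads = np.asarray(micro_grads, dtype=float)
    # Standard Adam: average over micro-batches first, then square.
    standard = micro_grads.mean(axis=0) ** 2
    # Batch size invariant Adam: square each micro-batch gradient, then average.
    invariant = (micro_grads ** 2).mean(axis=0)
    return standard, invariant

def expected_inputs(kappa, mu=0.1, sigma=1.0, trials=20000, seed=0):
    """Monte Carlo estimate of both quantities for kappa micro-batches.

    Hypothetical setup: each micro-batch gradient is mu plus Gaussian noise
    with standard deviation sigma; each column is an independent trial.
    """
    rng = np.random.default_rng(seed)
    grads = mu + sigma * rng.standard_normal((kappa, trials))
    standard, invariant = second_moment_inputs(grads)
    return standard.mean(), invariant.mean()

if __name__ == "__main__":
    for kappa in (1, 4, 16, 64):
        s, i = expected_inputs(kappa)
        print(f"kappa={kappa:3d}  average-then-square={s:.4f}  "
              f"square-then-average={i:.4f}")
```

Under these assumptions, the average-then-square value shrinks toward `mu**2` as `kappa` grows, while the square-then-average value stays close to `mu**2 + sigma**2` regardless of how many micro-batches make up the mini-batch.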

Paper Structure

This paper contains 15 sections, 2 theorems, 55 equations, 4 figures, 3 algorithms.

Key Result

Theorem 1

Consider two optimizers: Micro Adam (i.e. standard Adam applied to micro-batches) with hyperparameters $\eta$, $\gamma_1$ and $\gamma_2$, and batch size invariant Adam with hyperparameters applied to mini-batches composed of $\kappa$ micro-batches. We start both optimizers at time $t-\kappa$ at the same initial state, $w_{t-\kappa}$, $m_{t-\kappa}$

Figures (4)

  • Figure 1: Comparing the behavior of our proposed batch size invariant Adam (right column), with $\eta \propto B$, against standard Adam (left column), with $\eta \propto \sqrt{B}$ (Granziol et al. 2022; Malladi et al. 2022; Hilton et al. 2022). The model was a ResNet-18 trained on CIFAR-10 over 200 epochs, under batch sizes (opacity) ranging from $B=25$ to $B=800$, with $B_\mathrm{max}=800$. Each color represents a different base learning rate $\eta_0$. Note that batch size invariant Adam (right) gives almost perfect batch size invariance (in that the lines are all on top of each other) up until $\eta=10^{-3} \times (B/B_\text{max})$. In contrast, with standard Adam (left), you get discrepancies even with the smallest learning rate, i.e. $\eta_0 = 10^{-6}$.
  • Figure 2: As Fig. 1, but with layernorm rather than batchnorm. In particular, we compare the behavior of our proposed batch size invariant Adam, with $\eta \propto B$ (right), against standard Adam, with $\eta \propto \sqrt{B}$ (left), both with $B_{\rm max} = 800$. Similar to the batchnorm results, the batch size invariant Adam lines (right) almost exactly line up, until $\eta=10^{-3} \times (B/B_\text{max})$. In contrast, standard Adam (left) shows discrepancies between lines even under the smallest learning rate considered ($\eta_0 = 10^{-6}$).
  • Figure 3: As Fig. 1, but using ViT rather than ResNet-18. Again, we compare the behavior of our proposed batch size invariant Adam, with $\eta \propto B$ (right), against standard Adam, with $\eta \propto \sqrt{B}$ (left). A $B_\text{max}$ of 400 is used. Standard Adam shows aligned trajectories for the smallest and largest base learning rates but shows significant discrepancy at $\eta_0 = 10^{-5}$ and in the early stage of $\eta_0=10^{-4}$. This is expected, as obtaining batch size invariance with standard Adam under the square-root scaling requires the gradient variance to dominate the square of the gradient mean, which may not hold when the parameters are near initialization, where the gradient may have a large magnitude. Regardless, our proposed batch size invariant Adam shows consistency across all learning rates considered at all stages in the optimization.
  • Figure : Micro Adam (i.e. standard Adam applied to micro-batches)
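
The variance-dominance condition invoked in the Figure 3 caption can be made explicit with a short calculation. This is a simplified sketch, not the paper's derivation: it assumes, for a single parameter, i.i.d. per-example gradients $g_i$ with mean $\mu$ and variance $\sigma^2$ (symbols introduced here for illustration), and it ignores Adam's exponential moving averages and the $\epsilon$ term. For a mini-batch of size $B$ with average gradient $\bar{g}_B = \frac{1}{B}\sum_{i=1}^{B} g_i$, the quantity fed into standard Adam's second-moment estimate has expectation

$$\mathbb{E}\left[\bar{g}_B^2\right] = \mu^2 + \frac{\sigma^2}{B}.$$

If $\sigma^2/B \gg \mu^2$, then $\sqrt{\mathbb{E}[\bar{g}_B^2]} \approx \sigma/\sqrt{B}$ and the normalized step is roughly $\eta\,\mu\sqrt{B}/\sigma$. Because a fixed number of epochs takes a number of steps proportional to $1/B$, the total progress scales as $\eta/\sqrt{B}$, which stays constant under $\eta \propto \sqrt{B}$. When $\mu^2$ is not negligible, for instance near initialization, this argument no longer holds. Squaring before averaging over $\kappa$ micro-batches of fixed size $b$ (so that $B = \kappa b$) instead gives

$$\mathbb{E}\left[\frac{1}{\kappa}\sum_{k=1}^{\kappa} \bar{g}_{b,k}^2\right] = \mu^2 + \frac{\sigma^2}{b},$$

which does not depend on $B$. The normalized step is then roughly $\eta\,\mu/\sqrt{\mu^2 + \sigma^2/b}$, and the total progress scales as $\eta/B$, which stays constant under $\eta \propto B$ without any assumption on the relative size of $\mu^2$ and $\sigma^2$.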

Theorems & Definitions (2)

  • Theorem 1
  • Theorem 2