
Augment your batch: better training with larger batches

Elad Hoffer, Tal Ben-Nun, Itay Hubara, Niv Giladi, Torsten Hoefler, Daniel Soudry

TL;DR

The paper tackles the slowdown and generalization challenges of large-batch SGD by proposing Batch Augmentation (BA), which replicates each input within a batch under M different data augmentations to form a larger effective batch without extra I/O. BA acts as a variance-reducing regularizer: the gradients of a sample's augmented copies remain correlated, and BA mitigates convergence issues associated with a high $\lambda_{max}$ (the largest eigenvalue of the loss Hessian) and fixed learning-rate schedules. Through theoretical analysis and extensive experiments on CIFAR-10/100, ImageNet, PTB, and distributed systems, BA demonstrates faster convergence, improved final accuracy, and better hardware utilization, even at very large batch sizes. The work shows that BA can approximate or exceed the benefits of regime adaptation with less hyperparameter tuning, enabling scalable, efficient training on modern HPC resources.
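The batch-construction step can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: `batch_augment` and `random_shift` are hypothetical names, and the cyclic shift on a 1-D list merely stands in for the image augmentations (such as cropping, flipping, or cutout) a real pipeline would apply. Only the B distinct samples are loaded; their M copies are produced in memory, so the effective batch grows to B·M without extra I/O.

```python
import random
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def batch_augment(
    batch: Sequence[Tuple[T, int]],
    augment: Callable[[T, random.Random], T],
    m: int,
    rng: random.Random,
) -> List[Tuple[T, int]]:
    """Hypothetical helper: expand B (input, label) pairs into B*M pairs.

    Each input is repeated M times and every copy receives an independently
    sampled augmentation; labels are copied unchanged.
    """
    if m < 1:
        raise ValueError("m must be at least 1")
    augmented: List[Tuple[T, int]] = []
    for x, y in batch:
        for _ in range(m):
            augmented.append((augment(x, rng), y))
    return augmented


def random_shift(x: List[float], rng: random.Random) -> List[float]:
    """Toy augmentation (assumption): cyclically shift a 1-D signal by a random offset."""
    k = rng.randrange(len(x))
    return x[k:] + x[:k]


rng = random.Random(0)
batch = [([1.0, 2.0, 3.0, 4.0], 0), ([5.0, 6.0, 7.0, 8.0], 1)]  # B = 2
big_batch = batch_augment(batch, random_shift, m=4, rng=rng)
print(len(big_batch))  # 8 = B * M
```

In this sketch the training loss would then be averaged over all B·M entries as for any other batch; keeping each sample's copies contiguous is only a convenience, since the averaged loss does not depend on order.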

Abstract

Large-batch SGD is important for scaling training of deep neural networks. However, without fine-tuning hyperparameter schedules, the generalization of the model may be hampered. We propose to use batch augmentation: replicating instances of samples within the same batch with different data augmentations. Batch augmentation acts as a regularizer and an accelerator, increasing both generalization and performance scaling. We analyze the effect of batch augmentation on gradient variance and show that it empirically improves convergence for a wide variety of deep neural networks and datasets. Our results show that batch augmentation reduces the number of necessary SGD updates to achieve the same accuracy as the state-of-the-art. Overall, this simple yet effective method enables faster training and better generalization by allowing more computational resources to be used concurrently.
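To make the gradient-variance argument concrete, the following toy Monte Carlo experiment fits a scalar linear model with squared loss and uses additive Gaussian input noise as a stand-in augmentation. It is an illustrative assumption, not the paper's analysis, and all names and constants are hypothetical. It compares the variance of the batch-mean gradient for B distinct samples, for batch augmentation with the same B samples repeated M times, and for B·M distinct samples. Writing $\sigma_b^2$ for the variance of the per-sample expected gradient and $\sigma_w^2$ for the average within-sample (augmentation-induced) variance, the BA batch has variance $\sigma_b^2/B + \sigma_w^2/(BM)$, which lies between $(\sigma_b^2+\sigma_w^2)/B$ and $(\sigma_b^2+\sigma_w^2)/(BM)$ because copies of one sample share its underlying gradient.

```python
import random
import statistics


def sample_grad(w: float, x: float, y: float) -> float:
    """Gradient of the squared loss 0.5 * (w * x - y)**2 with respect to w."""
    return (w * x - y) * x


def batch_mean_grad(rng: random.Random, w: float, b: int, m: int, sigma: float) -> float:
    """Mean gradient over b distinct samples, each repeated m times with an
    independent toy augmentation (additive Gaussian input noise)."""
    total = 0.0
    for _ in range(b):
        x = rng.uniform(-2.0, 2.0)  # toy data distribution (assumption)
        y = 2.0 * x                 # toy target: the true weight is 2.0
        for _ in range(m):
            x_aug = x + rng.gauss(0.0, sigma)  # augmented copy; label unchanged
            total += sample_grad(w, x_aug, y)
    return total / (b * m)


def grad_variance(b: int, m: int, trials: int = 2000, w: float = 0.5,
                  sigma: float = 2.0, seed: int = 0) -> float:
    """Empirical variance of the batch-mean gradient across independent batches."""
    rng = random.Random(seed)
    grads = [batch_mean_grad(rng, w, b, m, sigma) for _ in range(trials)]
    return statistics.variance(grads)


v_small = grad_variance(b=4, m=1)   # B distinct samples
v_ba = grad_variance(b=4, m=8)      # BA: the same B samples, M = 8 copies each
v_large = grad_variance(b=32, m=1)  # B*M distinct samples
print(v_small > v_ba > v_large)     # True
```

For these toy settings the decomposition predicts roughly 4.13, 1.22, and 0.52 for the three cases, so BA removes most, but not all, of the variance that a batch of B·M distinct samples would remove.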

Paper Structure

This paper contains 18 sections, 1 theorem, 23 equations, 6 figures, 4 tables.

Key Result

Theorem 1

The iterates of SGD (Eq. SGD-1-1) will converge if … . In addition, this bound is tight, in the sense that it is also a necessary condition for certain datasets.
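The precise convergence condition of Theorem 1 is missing from this extract and is not reconstructed here. As a hedged illustration of why $\lambda_{max}$ constrains the learning rate, the sketch below uses the standard result for deterministic, full-batch gradient descent on a positive-definite quadratic with Hessian eigenvalues $\lambda_i$: each eigendirection evolves as $w_i \leftarrow (1 - \eta\lambda_i)\,w_i$, so the iterates converge exactly when $0 < \eta < 2/\lambda_{max}$. This is not the paper's SGD statement; the function names and the eigenvalue spectrum are hypothetical.

```python
from typing import Sequence


def max_stable_lr(eigs: Sequence[float]) -> float:
    """Learning-rate bound (exclusive) for GD on a positive-definite quadratic
    with these Hessian eigenvalues: 2 / lambda_max."""
    return 2.0 / max(eigs)


def gd_final_norm(eigs: Sequence[float], eta: float, steps: int = 500) -> float:
    """Run GD on f(w) = 0.5 * sum(lam_i * w_i**2) from w = (1, ..., 1) and
    return max_i |w_i| after `steps` updates."""
    w = [1.0] * len(eigs)
    for _ in range(steps):
        # The gradient is lam_i * w_i, so each coordinate is scaled by (1 - eta * lam_i).
        w = [wi - eta * lam * wi for wi, lam in zip(w, eigs)]
    return max(abs(wi) for wi in w)


eigs = [0.1, 1.0, 10.0]                  # hypothetical spectrum, lambda_max = 10
print(max_stable_lr(eigs))               # 0.2
print(gd_final_norm(eigs, 0.19) < 1e-3)  # True: just below the bound, converges
print(gd_final_norm(eigs, 0.21) > 1e6)   # True: just above the bound, diverges
```

Only the largest eigenvalue decides stability here: at $\eta = 0.21$ the $\lambda = 10$ direction is multiplied by $-1.1$ each step and blows up, even though the other directions would still converge.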

Figures (6)

  • Figure 1: Impact of batch augmentation (BA, with $M=4$) on ResNet-50 trained on ImageNet, showing training error (dashed) and validation error (solid).
  • Figure 2: Comparison of the gradient $L^2$ norm (ResNet44 + cutout, CIFAR-10, $B=64$) between the baseline ($M=1$) and batch augmentation with $M \in \{2,4,8,16,32\}$.
  • Figure 3: Impact of batch augmentation (ResNet44 + cutout, CIFAR-10). We used the original training regime (red) with $B=64$ and compared it to batch augmentation with $M \in \{2,4,8,16,32\}$, creating an effective batch of $64\cdot M$.
  • Figure 4: A comparison of (1) baseline training with $B=64$, (2) our batch augmentation (BA) method with $M=10$, and (3) regime adaptation (RA) with $B=640$ and 10x more epochs.
  • Figure 5: Training (dashed) and validation error over time (in hours) of ResNet50 with $B=256$ and $M=4$ (red) vs. $M=10$ (blue). The difference in runtime is negligible, while the higher batch-augmentation factor reaches lower error. Runtime for the baseline ($M=1$): $1.43\pm 0.13$ steps/second; $M=4$: $1.47\pm 0.13$ steps/second; $M=10$: $1.46\pm 0.14$ steps/second.
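The bookkeeping behind Figures 4 and 5 can be checked with simple arithmetic. The sketch below assumes the CIFAR-10 training set of 50,000 images (Figure 4's caption does not name the dataset; Figure 3 uses CIFAR-10 with $B=64$) and assumes a final partial batch is kept; `steps_per_epoch` and `samples_per_second` are hypothetical helpers. BA with $M=10$ still draws 64 distinct images per step, so it keeps the baseline's number of updates per epoch, whereas RA with $B=640$ performs roughly a tenth as many updates per epoch, which is consistent with pairing it with 10x more epochs. The second part multiplies the Figure 5 step rates by $B\cdot M$ to obtain augmented samples per second, taking the reported steps/second at face value.

```python
import math

CIFAR10_TRAIN_SIZE = 50_000  # standard CIFAR-10 training-set size (assumed setting)


def steps_per_epoch(num_samples: int, distinct_per_step: int) -> int:
    """SGD updates in one pass over the data, keeping the last partial batch."""
    return math.ceil(num_samples / distinct_per_step)


def samples_per_second(b: int, m: int, steps_per_sec: float) -> float:
    """Augmented samples processed per second: B distinct inputs x M copies per step."""
    return b * m * steps_per_sec


# Figure 4 settings
baseline = steps_per_epoch(CIFAR10_TRAIN_SIZE, 64)  # B=64
ba = steps_per_epoch(CIFAR10_TRAIN_SIZE, 64)        # BA, M=10: still 64 distinct images per step
ra = steps_per_epoch(CIFAR10_TRAIN_SIZE, 640)       # RA, B=640
print(baseline, ba, ra)  # 782 782 79

# Figure 5 measurements: ResNet50, B=256, mean steps/second for each M
for m, sps in [(1, 1.43), (4, 1.47), (10, 1.46)]:
    print(m, round(samples_per_second(256, m, sps), 1))
# 1 366.1
# 4 1505.3
# 10 3737.6
```

Because the step rate barely changes with $M$ in Figure 5, the number of augmented samples processed per second grows almost linearly with $M$ in this arithmetic, which is the sense in which BA puts otherwise idle compute to use.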

Theorems & Definitions (1)

  • Theorem 1