Don't Use Large Mini-Batches, Use Local SGD

Tao Lin, Sebastian U. Stich, Kumar Kshitij Patel, Martin Jaggi

TL;DR

The paper tackles the generalization gap observed when stochastic gradient methods are scaled to very large mini-batches in distributed deep learning. It proposes local SGD, a two-phase variant called post-local SGD, and hierarchical local SGD for heterogeneous systems. All three trade local computation for less frequent communication while preserving or improving generalization. Experiments on CIFAR, ImageNet, and language modeling show that local SGD improves time-to-accuracy and that post-local SGD closes the large-batch generalization gap, often outperforming both small- and large-batch baselines. The paper also offers an explanation of why these methods generalize better, linking local updates to structured stochastic noise and convergence to flatter minima. Overall, the approach scales communication-efficient distributed training without falling back to excessively small batches or ad-hoc tuning.
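A minimal sketch of the local SGD update described above, assuming a toy least-squares objective, NumPy, and plain SGD without momentum or weight decay. These are illustrative assumptions, not the paper's training setup, and the helper names `grad`, `local_sgd_round`, and `minibatch_sgd_step` are hypothetical. Each of the $K$ workers takes $H$ SGD steps on its own mini-batches of size $B_{\text{loc}}$, and only then are the $K$ models averaged. With $H = 1$, one round is exactly a mini-batch SGD step with global batch $K B_{\text{loc}}$.

```python
import numpy as np


def grad(w, X, y):
    """Gradient of the toy loss 0.5 * mean((X @ w - y) ** 2) (an assumption)."""
    return X.T @ (X @ w - y) / len(y)


def local_sgd_round(w, batches, lr):
    """One communication round of local SGD (hypothetical helper).

    batches[k] is a list of H (X, y) mini-batches for worker k.
    Every worker starts from the shared model w, runs H sequential
    SGD steps on its own data without communicating, and the round
    ends by averaging the K local models (the only communication).
    """
    local_models = []
    for worker_batches in batches:
        w_k = w.copy()
        for X, y in worker_batches:  # H local steps, no communication
            w_k = w_k - lr * grad(w_k, X, y)
        local_models.append(w_k)
    return np.mean(local_models, axis=0)  # stands in for an all-reduce average


def minibatch_sgd_step(w, step_batches, lr):
    """One mini-batch SGD step: average the K worker gradients, then update."""
    grads = [grad(w, X, y) for X, y in step_batches]
    return w - lr * np.mean(grads, axis=0)
```

Because averaging $K$ models that each took one step from the same $w$ equals one step along the averaged gradient, `local_sgd_round` with one batch per worker reproduces `minibatch_sgd_step`; larger $H$ simply inserts more local steps between averages.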

Abstract

Mini-batch stochastic gradient methods (SGD) are state of the art for distributed training of deep neural networks. Drastic increases in mini-batch sizes have led to key efficiency and scalability gains in recent years. However, progress faces a major roadblock, as models trained with large batches often do not generalize well, i.e., they do not achieve good accuracy on new data. As a remedy, we propose post-local SGD and show that it significantly improves generalization performance compared to large-batch training on standard benchmarks while enjoying the same efficiency (time-to-accuracy) and scalability. We further provide an extensive study of the communication efficiency vs. performance trade-offs associated with a host of local SGD variants.
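Post-local SGD, as defined in the caption of Figure 1, first runs ordinary mini-batch SGD ($H = 1$) and then switches to local SGD from the resulting model at a chosen epoch (epoch $150$ in that experiment). A minimal sketch of this two-phase schedule follows. The helpers `local_steps_for_epoch` and `communication_rounds` are hypothetical, the post-switch value `h_after = 16` is chosen only for illustration, and the round count assumes the steps per epoch are a multiple of $H$.

```python
def local_steps_for_epoch(epoch, switch_epoch=150, h_after=16):
    """Local steps H used in `epoch` (0-indexed) under a post-local SGD schedule.

    Phase 1 (epoch < switch_epoch): H = 1, i.e. mini-batch SGD, with
    synchronization after every step.
    Phase 2 (epoch >= switch_epoch): local SGD with H = h_after steps
    between model averages. h_after=16 is a hypothetical default.
    """
    return 1 if epoch < switch_epoch else h_after


def communication_rounds(num_epochs, steps_per_epoch, switch_epoch=150, h_after=16):
    """Count model-averaging rounds over training under the schedule above.

    Assumes steps_per_epoch is a multiple of H (integer division otherwise
    drops a partial final round).
    """
    rounds = 0
    for epoch in range(num_epochs):
        h = local_steps_for_epoch(epoch, switch_epoch, h_after)
        rounds += steps_per_epoch // h
    return rounds
```

For example, `communication_rounds(200, 32)` gives $150 \cdot 32 + 50 \cdot 2 = 4900$ rounds. Only the second phase reduces synchronization, which is where the method saves communication.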

Paper Structure

This paper contains 82 sections, 15 equations, 20 figures, 16 tables, 5 algorithms.

Figures (20)

  • Figure 1: Illustration of the generalization gap. Large-batch SGD (A2, blue) matches the training curves of small-batch SGD (A1, green), i.e. it has no optimization difficulty (left & middle). However, it does not reach the same test accuracy (right), while the proposed post-local SGD (A5, red) does. Post-local SGD (A5) is defined by starting local SGD from the model obtained by large-batch SGD (A2) at epoch $150$. Mini-batch SGD with an even larger mini-batch size (A3, yellow) additionally suffers from optimization issues. Experiments are for ResNet-20 on CIFAR-10 ($B_{\text{loc}}\!=\!128$), with a fine-tuned learning rate for mini-batch SGD and the warmup scheme of goyal2017accurate. The inline table highlights the comparison of system/generalization performance for different algorithms.
  • Figure 2: Scaling behavior of local SGD in clock time for an increasing number of workers $K$ and different numbers of local steps $H$, training ResNet-20 on CIFAR-10 with $B_{\text{loc}}\!=\!128$. The reported speedup (averaged over three runs) is relative to single-GPU training time for reaching the baseline top-1 test accuracy ($91.2\%$, as in he2016deep). We use an $8 \!\times\! 2$-GPU cluster with a $10$ Gbps network. $H\!=\!1$ recovers mini-batch SGD.
  • Figure 3: Training ResNet-20 on CIFAR-10 under different $K$ and $H$, with fixed $B_{\text{loc}}=128$. All results are averaged over three runs, and all settings access the same total number of training samples. We fine-tune the learning rate of mini-batch SGD for each setting.
  • Figure 4: Top-1 test accuracy of training ResNet-20 on CIFAR-10. Box-plot figures are derived from 3 runs.
  • Figure 5: Understanding the generalization ability of post-local SGD for large-batch training (ResNet-20 on CIFAR-10 with $B K \!=\! B_{\text{loc}} K \!=\! 2048$). We use fixed $B \!=\! B_{\text{loc}} \!=\! 128$ with $K\!=\!16$ GPUs. The detailed experimental setup and further visualizations of the results are given in the Appendix.
  • ...and 15 more figures
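The captions above are governed by three quantities: the number of workers $K$, the local mini-batch size $B_{\text{loc}}$, and the number of local steps $H$, with $H = 1$ recovering mini-batch SGD. The following small arithmetic sketch shows how they set the communication budget. It assumes one model average per round, a dataset of $N$ samples (CIFAR-10 has 50,000 training images), and that an incomplete final round is dropped. The helper names are hypothetical.

```python
def samples_per_round(K, H, B_loc):
    """Training samples consumed between two consecutive model averages."""
    return K * H * B_loc


def rounds_per_epoch(N, K, H, B_loc):
    """Model-averaging rounds per pass over N samples (partial last round dropped)."""
    return N // samples_per_round(K, H, B_loc)
```

In the Figure 5 setting ($K = 16$, $B_{\text{loc}} = 128$), mini-batch SGD ($H = 1$) consumes $2048$ samples per round and averages $24$ times per pass over CIFAR-10. With $H = 16$, a round consumes $32768$ samples, so the models are averaged only once per pass.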