Image Classification at Supercomputer Scale
Chris Ying, Sameer Kumar, Dehao Chen, Tao Wang, Youlong Cheng
TL;DR
Scaling deep learning to petascale hardware poses both algorithmic and systems challenges. The paper presents three system-level optimizations (distributed batch normalization, input-pipeline improvements, and a 2-D torus all-reduce) that enable massive data-parallel training on TPU clusters. Through controlled experiments and a full integration, the authors train ResNet-50 on ImageNet to 76.3% accuracy in 2.2 minutes on a 1024-chip TPU v3 Pod, with throughput exceeding 1.05 million images/second and no accuracy loss. The results illustrate that, with careful systems design, large-batch training can maintain model quality while delivering record-setting training speed, using methods presented as generalizable to non-TPU hardware.
Abstract
Deep learning is extremely computationally intensive, and hardware vendors have responded by building faster accelerators in large clusters. Training deep learning models at petaFLOPS scale requires overcoming both algorithmic and systems software challenges. In this paper, we discuss three systems-related optimizations: (1) distributed batch normalization to control per-replica batch sizes, (2) input pipeline optimizations to sustain model throughput, and (3) 2-D torus all-reduce to speed up gradient summation. We combine these optimizations to train ResNet-50 on ImageNet to 76.3% accuracy in 2.2 minutes on a 1024-chip TPU v3 Pod with a training throughput of over 1.05 million images/second and no accuracy drop.
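To make the third optimization more concrete, the following is a minimal sketch of the arithmetic behind a two-phase gradient summation on a 2-D grid of workers. It is a toy model, not the authors' implementation: the function name `all_reduce_2d`, the nested-list layout, and the choice to sum within columns first are all assumptions made for illustration, and the sketch ignores ring communication, link bandwidth, and the torus wrap-around entirely.

```python
# Toy model of a two-phase all-reduce on a 2-D grid of workers.
# Only the arithmetic is modeled; a real system exchanges partial
# sums over network links rather than looping in one process.

def all_reduce_2d(grads):
    """Sum gradient vectors over a grid of workers in two phases.

    grads[r][c] is the gradient vector (list of floats) held by the
    hypothetical worker at grid position (r, c). Returns a grid of the
    same shape in which every worker holds the global sum.
    """
    rows, cols = len(grads), len(grads[0])
    dim = len(grads[0][0])

    # Phase 1: sum within each column (across the row index).
    col_sums = []
    for c in range(cols):
        total = [0.0] * dim
        for r in range(rows):
            for i in range(dim):
                total[i] += grads[r][c][i]
        col_sums.append(total)

    # Phase 2: sum the per-column results across columns.
    global_sum = [0.0] * dim
    for total in col_sums:
        for i in range(dim):
            global_sum[i] += total[i]

    # Every worker ends with its own copy of the global sum.
    return [[list(global_sum) for _ in range(cols)] for _ in range(rows)]


# Usage: a 2 x 2 grid of workers, each holding a 2-element gradient.
grid = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
]
result = all_reduce_2d(grid)
```

Splitting the reduction by dimension means each phase only involves the workers in one row or one column of the grid, which is the intuition behind using a 2-D layout instead of a single ring over all workers.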
