Image Classification at Supercomputer Scale

Chris Ying, Sameer Kumar, Dehao Chen, Tao Wang, Youlong Cheng

TL;DR

Scaling deep learning to petascale hardware poses algorithmic and systems challenges. The paper presents three system-level optimizations (distributed batch normalization, input-pipeline improvements, and a 2-D torus all-reduce) to enable massive data-parallel training on TPU clusters. Through controlled experiments and a full integration, the authors demonstrate ResNet-50 on ImageNet achieving 76.3% accuracy in 2.2 minutes on a 1024-chip TPU v3 Pod, with throughput exceeding 1.05M images/s and no accuracy loss. The results illustrate that with careful systems design, large-batch training can maintain model quality while delivering practical, world-record training speedups, with methods generalizable to non-TPU hardware.

Abstract

Deep learning is extremely computationally intensive, and hardware vendors have responded by building faster accelerators in large clusters. Training deep learning models at petaFLOPS scale requires overcoming both algorithmic and systems software challenges. In this paper, we discuss three systems-related optimizations: (1) distributed batch normalization to control per-replica batch sizes, (2) input pipeline optimizations to sustain model throughput, and (3) 2-D torus all-reduce to speed up gradient summation. We combine these optimizations to train ResNet-50 on ImageNet to 76.3% accuracy in 2.2 minutes on a 1024-chip TPU v3 Pod with a training throughput of over 1.05 million images/second and no accuracy drop.
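To make optimization (1) concrete, the following is a minimal sketch of how batch-normalization statistics could be shared across a subgroup of replicas, so that the effective normalization batch is larger than the per-replica batch. It is a plain-Python simulation under stated assumptions, not the authors' implementation: the function name `distributed_batch_norm_stats`, the contiguous grouping of replicas, and the use of summed first and second moments are all illustrative choices.

```python
def distributed_batch_norm_stats(per_replica_values, group_size):
    """Return one (mean, variance) pair per replica, computed over its subgroup.

    per_replica_values[r] holds the activations of a single channel on
    replica r. Replicas are grouped contiguously (an assumption made here):
    [0 .. g-1], [g .. 2g-1], and so on.
    """
    if len(per_replica_values) % group_size != 0:
        raise ValueError("replica count must be divisible by group_size")

    stats = []
    for start in range(0, len(per_replica_values), group_size):
        group = per_replica_values[start:start + group_size]
        # Each replica contributes local sums; adding them across the
        # subgroup stands in for a subgroup all-reduce.
        count = sum(len(v) for v in group)
        total = sum(sum(v) for v in group)
        total_sq = sum(sum(x * x for x in v) for v in group)
        mean = total / count
        var = total_sq / count - mean * mean  # biased variance: E[x^2] - E[x]^2
        # Every replica in the subgroup ends up with the same statistics.
        stats.extend([(mean, var)] * len(group))
    return stats


def normalize(values, mean, var, eps=1e-5):
    """Normalize one replica's activations with the shared statistics."""
    return [(x - mean) / (var + eps) ** 0.5 for x in values]


# Four replicas, subgroup size 2 (the subgroup size used in the paper's Figure 3).
replicas = [[1.0, 3.0], [5.0, 7.0], [2.0, 2.0], [4.0, 4.0]]
stats = distributed_batch_norm_stats(replicas, group_size=2)
```

With a subgroup size of 2, replicas 0 and 1 normalize with statistics computed over four values instead of two, which is the sense in which the subgroup size controls the effective batch-normalization batch.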

Paper Structure

This paper contains 20 sections, 7 figures, 1 table.

Figures (7)

  • Figure 1: Cloud TPU v2 device with four chips, 180 teraFLOPS of peak floating point throughput and 64 GB of High Bandwidth Memory (HBM).
  • Figure 2: Liquid-cooled Cloud TPU v3 device with four chips, 420 teraFLOPS of peak floating point throughput and 128 GB of HBM.
  • Figure 3: Distributed batch normalization algorithm. In this example, the subgroup size is 2.
  • Figure 4: [best viewed in color] 2-D all-reduce across a hypothetical $3 \times 3$ torus. In the first phase (left), the first half of the tensors (blue) are summed along the vertical dimension while the second half of the tensors (red) are summed concurrently along the horizontal dimension. In the second phase (right), the dimensions are flipped, which completes the all-reduce for both halves.
  • Figure 5: Error bars show the 25th and 75th percentiles. (left) Controlled additions of each of the 4 optimizations. (center) Controlled ablations of each of the 4 optimizations; the yellow bar indicates the final pipeline that we used, which includes all optimizations. (right) Varying the number of parallel threads with all other optimizations enabled.
  • ...and 2 more figures
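
The two-phase scheme described in the Figure 4 caption can be sketched as a small simulation. This is only an illustration of the data flow, assuming a grid of chips that each hold a list of gradient values. The function name `all_reduce_2d` and the helper names are hypothetical, and each per-dimension reduction is modeled as a direct sum rather than the link-level communication a real torus would use.

```python
def all_reduce_2d(grid):
    """Simulate a two-phase 2-D all-reduce.

    grid[r][c] is the list of gradient values held by the chip at row r,
    column c. Returns a grid of the same shape in which every chip holds
    the element-wise sum over all chips.
    """
    rows, cols = len(grid), len(grid[0])
    half = len(grid[0][0]) // 2

    def add(vectors):
        # Element-wise sum of equally sized lists.
        return [sum(vals) for vals in zip(*vectors)]

    def reduce_columns(g):
        # Vertical dimension: every chip receives the sum over its column.
        col_sums = [add([g[r][c] for r in range(rows)]) for c in range(cols)]
        return [[list(col_sums[c]) for c in range(cols)] for _ in range(rows)]

    def reduce_rows(g):
        # Horizontal dimension: every chip receives the sum over its row.
        row_sums = [add(g[r]) for r in range(rows)]
        return [[list(row_sums[r]) for _ in range(cols)] for r in range(rows)]

    first = [[t[:half] for t in row] for row in grid]    # "blue" half
    second = [[t[half:] for t in row] for row in grid]   # "red" half

    # Phase 1: the first half is summed vertically while the second half
    # is summed horizontally (concurrently on real hardware).
    first = reduce_columns(first)
    second = reduce_rows(second)

    # Phase 2: the dimensions are flipped, completing both halves.
    first = reduce_rows(first)
    second = reduce_columns(second)

    return [[first[r][c] + second[r][c] for c in range(cols)]
            for r in range(rows)]


# A hypothetical 3 x 3 torus where chip k holds [k, 1, 10 * k, 2].
grid = [[[r * 3 + c, 1, 10 * (r * 3 + c), 2] for c in range(3)]
        for r in range(3)]
result = all_reduce_2d(grid)
```

Splitting the tensor in half lets both dimensions carry traffic in each phase, which is the motivation the caption gives for processing the two halves concurrently.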