DeMo: Decoupled Momentum Optimization
Bowen Peng, Lizhang Chen, Baiyu Su, Jeffrey Quesnelle, Diederik P. Kingma, Qiang Liu
TL;DR
Addresses the gradient-communication bottleneck in synchronous data-parallel training of large language models. Introduces DeMo, which decouples momentum updates across workers and communicates only a compressed form of the momentum, obtained with a blockwise transform (e.g., DCT) and top-$k$ sparsification; subtracting the transmitted components from the local momentum serves as implicit error feedback. The communicated momentum then drives the base optimizer's parameter update, with a convergence rate of $O(1/\sqrt{T})$ under standard assumptions. Empirically, on 300M- and 1B-parameter decoder-only models, DeMo reduces per-GPU data transfer by up to 85x compared to AdamW-DDP while preserving loss and accuracy; at the 1B scale, certain settings even yield better downstream performance at modest communication cost. DeMo is topology-agnostic, enabling training across multi-datacenter or Ethernet-based networks, and the authors release code to support reproducibility.
Abstract
Scaling neural network training increasingly depends on synchronous data-parallelism, yet full-precision gradient all-reduce imposes a severe communication bottleneck. We propose Decoupled Momentum Optimization (DeMo), a drop-in replacement for any momentum-based optimizer that significantly reduces the communication bandwidth while maintaining convergence. DeMo (i) decouples local momentum updates, (ii) applies a fast orthonormal transform (e.g., DCT) followed by top-k sparsification, and (iii) reuses the momentum buffer as error feedback via momentum subtraction. This design reduces per-step communication by up to two orders of magnitude with minimal computational overhead. Experiments on 300M- and 1B-parameter language models show DeMo transmits up to 85x less data per GPU than AdamW-DDP while achieving comparable loss and accuracy. DeMo is topology-agnostic and enables training across multi-datacenter or Ethernet-based setups. Code is available at https://github.com/bloc97/DeMo
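The three steps listed in the abstract can be illustrated with a minimal NumPy sketch. This is not the released implementation: the function names `dct_matrix` and `demo_local_step`, the block size, the value of `k`, the momentum form `beta * m + g`, and the flat 1-D parameter layout are all assumptions chosen for illustration. The sketch only shows one worker's local step: accumulate momentum, take a blockwise orthonormal DCT, keep the top-k coefficients per block, and subtract the transmitted part from the momentum buffer.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II matrix of size n x n (rows are basis vectors)."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    c = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    c[0] /= np.sqrt(2.0)  # rescale the DC row so that c @ c.T == I
    return c

def demo_local_step(momentum, grad, beta=0.9, block=8, k=2):
    """One illustrative local step on a flat vector.

    Assumes len(momentum) is divisible by `block`.
    Returns (residual_momentum, transmitted, sparse_coeffs).
    """
    C = dct_matrix(block)
    # (i) Local momentum accumulation, with no cross-worker synchronization.
    m = beta * momentum + grad
    # (ii) Blockwise orthonormal transform, then top-k by magnitude per block.
    coeffs = m.reshape(-1, block) @ C.T
    idx = np.argsort(-np.abs(coeffs), axis=1)[:, :k]
    sparse = np.zeros_like(coeffs)
    np.put_along_axis(sparse, idx, np.take_along_axis(coeffs, idx, axis=1), axis=1)
    # Decode the kept coefficients back to parameter space (inverse DCT).
    transmitted = (sparse @ C).reshape(m.shape)
    # (iii) Momentum subtraction: whatever was not sent stays in the buffer,
    # so it acts as error feedback on later steps.
    residual = m - transmitted
    return residual, transmitted, sparse

# Hypothetical usage: two simulated workers average their decoded components.
rng = np.random.default_rng(0)
results = [demo_local_step(np.zeros(32), rng.normal(size=32)) for _ in range(2)]
averaged_update = np.mean([r[1] for r in results], axis=0)
```

In a real distributed run, only the indices and values of the nonzero entries of `sparse` would need to be exchanged between workers, which is where the bandwidth saving comes from. How the averaged components are fed into the base optimizer is not specified in the abstract, so the sketch stops at `averaged_update`.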
