AdaPM: a Partial Momentum Algorithm for LLM Training
Yimu Zhang, Yuanshi Liu, Cong Fang
TL;DR
AdaPM tackles the memory bottleneck of momentum-based optimizers in LLM training with a non-uniform, adaptive momentum design across Transformer blocks and a debiased low-rank estimator. It disables momentum for embedding and attention-output blocks, applies low-rank, bias-corrected momentum to Q/K/MLP blocks, and keeps full momentum for Value blocks, which substantially reduces memory while maintaining convergence comparable to AdamW. Experiments on GPT-2 and Llama models show momentum memory savings above $90\%$ and, when combined with Adam-mini, optimizer-state savings of up to $95\%$, along with a reduction of over $30\%$ in GPU hours for pretraining GPT-2 1.5B. The approach scales to larger models, carries over to supervised fine-tuning and RLHF, and is compatible with methods that reduce the memory of second-order statistics, offering a scalable path to efficient LLM pretraining and fine-tuning.
Abstract
In the training of large language models, momentum is widely used and often yields significant acceleration. However, storing momentum imposes a substantial memory cost. In this paper, we propose AdaPM, an adaptive training strategy that leverages partial momentum to implement a memory-efficient optimizer. AdaPM uses a non-uniform momentum design: for most blocks, full momentum is not necessary to preserve optimization performance. To mitigate the bias and performance loss caused by partial momentum, we augment it with a bias correction technique. Empirically, we verify that our approach reduces momentum memory by over $90\%$ while maintaining both efficiency and performance when pretraining language models ranging from 60M to 1.5B parameters, as well as in supervised fine-tuning and RLHF. Combined with a memory-efficient technique for the second-order statistic, AdaPM further reduces optimizer-state memory by up to $95\%$ and saves over $30\%$ of GPU hours for pretraining GPT-2 1.5B.
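To make the block-wise policy from the TL;DR concrete, the following is a minimal accounting sketch, not the authors' implementation. It counts how many momentum entries would be stored under that policy. The names `POLICY`, `momentum_entries`, and `momentum_savings`, the toy layer shapes, and the rank are all hypothetical, and the low-rank case assumes the momentum is kept as two factors of shapes (rows, rank) and (rank, cols), which the text above does not specify. The resulting number is purely illustrative and does not reproduce any reported result.

```python
# Momentum policy per block type, following the TL;DR description:
#   "none"     -> no momentum is stored
#   "low_rank" -> only low-rank factors are stored (assumed storage format)
#   "full"     -> dense momentum, same shape as the weight
POLICY = {
    "embedding": "none",
    "attn_out": "none",
    "query": "low_rank",
    "key": "low_rank",
    "mlp": "low_rank",
    "value": "full",
}


def momentum_entries(kind, rows, cols, rank):
    """Number of stored momentum entries for one weight matrix."""
    mode = POLICY[kind]
    if mode == "none":
        return 0
    if mode == "full":
        return rows * cols
    # Assumption: two factors of shapes (rows, rank) and (rank, cols).
    return rank * (rows + cols)


def momentum_savings(blocks, rank):
    """Fraction of momentum memory saved relative to dense momentum everywhere."""
    dense = sum(rows * cols for _, rows, cols in blocks)
    kept = sum(momentum_entries(kind, rows, cols, rank) for kind, rows, cols in blocks)
    return 1.0 - kept / dense


# Hypothetical toy model: one embedding table plus a single Transformer layer.
toy_blocks = [
    ("embedding", 32000, 1024),
    ("query", 1024, 1024),
    ("key", 1024, 1024),
    ("value", 1024, 1024),
    ("attn_out", 1024, 1024),
    ("mlp", 1024, 4096),
    ("mlp", 4096, 1024),
]

savings = momentum_savings(toy_blocks, rank=8)
print(f"momentum memory saved on the toy model: {savings:.1%}")
```

In this accounting, the Value blocks dominate the remaining momentum memory because they are the only ones kept dense, so the achievable savings depend mainly on how large those blocks are relative to the rest of the model.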
