AdaPM: a Partial Momentum Algorithm for LLM Training

Yimu Zhang, Yuanshi Liu, Cong Fang

TL;DR

AdaPM tackles the memory bottleneck of momentum-based optimizers in LLM training by introducing a non-uniform, adaptive momentum design across Transformer blocks and a debiased low-rank estimator. By disabling momentum for embedding and attention-output blocks, applying a low-rank, bias-corrected momentum to Q/K/MLP blocks, and preserving full momentum for Value blocks, it achieves substantial memory reductions while maintaining convergence comparable to AdamW. Empirical results across GPT-2 and Llama models show momentum memory savings above $90\%$ and optimizer-state savings up to $95\%$ when combined with Adam-mini, along with significant GPU-hour reductions and robust scaling to larger models and RLHF pipelines. The approach also demonstrates compatibility with second-order statistics reduction methods and provides a scalable pathway for efficient, large-scale LLM pretraining and fine-tuning.
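
The block-wise policy described above can be sketched as a small lookup plus a count of stored momentum entries. This is an illustrative sketch only, not the authors' implementation: the block names (`"embedding"`, `"attn_output"`, `"query"`, `"key"`, `"mlp"`, `"value"`), the helper names, and the assumption that a rank-`r` momentum for an `m x n` weight is stored as two factors with `r * (m + n)` entries are all hypothetical choices made for illustration.

```python
from enum import Enum


class MomentumMode(Enum):
    NONE = "none"          # no momentum buffer is kept
    LOW_RANK = "low_rank"  # low-rank (bias-corrected) momentum
    FULL = "full"          # ordinary full momentum buffer


def momentum_mode(block: str) -> MomentumMode:
    """Map a (hypothetical) block name to the momentum design in the summary."""
    if block in {"embedding", "attn_output"}:
        return MomentumMode.NONE
    if block in {"query", "key", "mlp"}:
        return MomentumMode.LOW_RANK
    if block == "value":
        return MomentumMode.FULL
    raise ValueError(f"unknown block type: {block!r}")


def momentum_entries(block: str, rows: int, cols: int, rank: int) -> int:
    """Count stored momentum entries for one weight matrix.

    Assumption: a rank-r momentum is stored as two factors of shapes
    (rows, r) and (r, cols), i.e. r * (rows + cols) entries.
    """
    mode = momentum_mode(block)
    if mode is MomentumMode.NONE:
        return 0
    if mode is MomentumMode.LOW_RANK:
        return rank * (rows + cols)
    return rows * cols


if __name__ == "__main__":
    # Hypothetical 768 x 768 projections with rank 8.
    for name in ("embedding", "query", "value"):
        print(name, momentum_mode(name).value, momentum_entries(name, 768, 768, 8))
```

Under these assumptions, a low-rank block stores far fewer entries than a full buffer of the same shape. The overall saving in a real model depends on how parameters are distributed across block types, which this sketch does not model.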

Abstract

In the training of large language models, momentum is widely used and often demonstrated to achieve significant acceleration. However, storing momentum typically presents memory challenges. In this paper, we propose AdaPM, an adaptive training strategy that leverages partial momentum to implement a memory-efficient optimizer. To this end, AdaPM utilizes a non-uniform momentum design: for most blocks, full momentum is not necessary to preserve the performance of the optimization. To mitigate the bias and performance loss caused by partial momentum, AdaPM enhances the partial momentum with a bias-correction technique. Empirically, we verify that our approach reduces momentum memory by over $90\%$ while maintaining both efficiency and performance for pretraining various language models ranging from 60M to 1.5B parameters, as well as for supervised fine-tuning and RLHF. When combined with a memory-efficient technique for the second-order statistic, AdaPM can further reduce optimizer-state memory by up to $95\%$, saving over $30\%$ of GPU hours for pretraining GPT-2 1.5B.

Paper Structure

This paper contains 21 sections, 3 theorems, 12 equations, 6 figures, 2 tables, 4 algorithms.

Key Result

Theorem 1

Set a constant stepsize of $\eta = \Theta(1)$ and the number of iterations $T$. Then the validation loss of vanilla SGD is bounded by $\tilde{\mathcal{O}}\left( T^{1/a - 1} + T^{1/a - b/a} \right)$. For the accelerated SGD method with momentum $1 - \beta$ (where $\beta \in (0,1]$), the validation lo

Figures (6)

  • Figure 1: AdaPM uses less memory and can reach higher throughput, with performance on par with or better than AdamW. (a) Results for GPT-2 1.5B pre-training. (b) The memory cost when training GPT-2 1.5B with various optimizers. The experimental details are given in Section \ref{sec:ex-pre}. (c) AdaPM assigns different momentum designs to different blocks and enhances the partial momentum using a bias-corrected approach.
  • Figure 2: The spectral distribution of features in each block of the 10th layer of GPT-2 124M at $10\%$ of the training steps.
  • Figure 3: (a)-(b): Loss curves of pre-training the GPT-2 series from 124M to 330M. The GPT-2 1.5B pre-training result is in Section \ref{sec:intro}. (c)-(d): Test loss of pre-training the Llama-2 series from 130M to 340M. AdaPM performs on par with or better than AdamW, while other methods perform worse.
  • Figure 4: (a) Loss curves of pre-training the GPT-2 series with and without bias correction. (b) Applying AdaPM to pre-training GPT-2 1.5B with different ranks and update frequencies $T$. (c) Applying AdaPM to Adam-mini for pre-training GPT-2 1.5B.
  • Figure 5: (a) Scaling laws in terms of parameters suggest that AdaPM can be scaled up to larger models (if the scaling law holds). (b)-(c): SFT and RLHF when aligning Llama3-8B. AdaPM maintains evaluation perplexity and reward similar to AdamW with $43\%$ less memory.
  • ...and 1 more figure

Theorems & Definitions (5)

  • Theorem 1: Validation Loss Rates for SGD and Accelerated SGD
  • Theorem 3
  • Theorem 4: Upper Bound of Accelerated SGD
  • proof
  • proof