Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees
Thien Hang Nguyen, Huy Le Nguyen
TL;DR
This work tackles the memory bottleneck of adaptive optimizers in large-scale neural networks by introducing Subset-Norm (SN) and Subspace-Momentum (SM). SN reduces AdaGrad-style memory from $O(d)$ to $O(\sqrt{d})$ by sharing step sizes across subsets of parameters, and comes with a high-probability convergence guarantee under coordinate-wise sub-Gaussian noise. SM confines momentum to a low-dimensional subspace of dimension $k$ while performing SGD in the orthogonal complement, with a convergence guarantee similar to that of SGD. The combination, SNSM, further reduces optimizer-state memory to roughly $k+\sqrt{d}$ and yields practical gains in LLM pre-training and fine-tuning, including substantial memory savings and faster attainment of competitive perplexity. The theoretical results are complemented by extensive experiments on LLaMA-scale models, showing that SN and SM not only save memory but can also improve training efficiency and perplexity with minimal hyperparameter tuning. Overall, the paper provides a principled framework with convergence guarantees for memory-efficient optimization in deep learning, backed by strong empirical results in large-scale training settings.
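To make the step-size sharing concrete, here is a minimal pure-Python sketch of a Subset-Norm update in its cumulative AdaGrad-style form. It assumes contiguous subsets of about $\sqrt{d}$ coordinates and a small initial accumulator `b0`; the names `SubsetNormSGD`, `make_subsets`, `subset_size`, and `b0` are hypothetical, and the paper's own partitioning and practical variant may differ.

```python
import math


def make_subsets(d, subset_size):
    """Split indices 0..d-1 into contiguous groups of at most subset_size (assumed partitioning)."""
    return [list(range(s, min(s + subset_size, d))) for s in range(0, d, subset_size)]


class SubsetNormSGD:
    """Hypothetical sketch: one AdaGrad-Norm accumulator shared by each subset of coordinates."""

    def __init__(self, d, lr=0.1, subset_size=None, b0=1e-8):
        if subset_size is None:
            subset_size = max(1, math.isqrt(d))  # about sqrt(d) coordinates per subset
        self.subsets = make_subsets(d, subset_size)
        self.lr = lr
        # One scalar per subset, i.e. about d / subset_size numbers of optimizer state
        self.acc = [b0 ** 2] * len(self.subsets)

    def step(self, x, g):
        """Update x in place using gradient g and return it."""
        for i, idx in enumerate(self.subsets):
            # Accumulate the squared norm of this subset's gradient slice
            self.acc[i] += sum(g[j] ** 2 for j in idx)
            # Every coordinate in the subset shares the same step size
            step_size = self.lr / math.sqrt(self.acc[i])
            for j in idx:
                x[j] -= step_size * g[j]
        return x


if __name__ == "__main__":
    opt = SubsetNormSGD(16, lr=0.1)
    x = [1.0] * 16
    for _ in range(100):
        x = opt.step(x, list(x))  # gradient of 0.5 * ||x||^2 is x itself
    print("accumulators:", len(opt.acc), "loss:", 0.5 * sum(v * v for v in x))
```

Setting `subset_size=1` recovers per-coordinate AdaGrad and `subset_size=d` recovers AdaGrad-Norm; choosing about $\sqrt{d}$ coordinates per subset leaves about $\sqrt{d}$ accumulators, which is where the $O(\sqrt{d})$ memory figure comes from.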
Abstract
We introduce two complementary techniques for efficient optimization that reduce memory requirements while accelerating the training of large-scale neural networks. The first technique, the Subset-Norm step size, generalizes AdaGrad-Norm and AdaGrad(-Coordinate) through step-size sharing. Subset-Norm (SN) reduces AdaGrad's memory footprint from $O(d)$ to $O(\sqrt{d})$, where $d$ is the model size. For non-convex smooth objectives under coordinate-wise sub-Gaussian noise, we show a noise-adapted high-probability convergence guarantee for SN with improved dimensional dependence over existing methods. Our second technique, Subspace-Momentum, reduces the memory footprint of the momentum state by restricting momentum to a low-dimensional subspace while performing SGD in the orthogonal complement. We prove a high-probability convergence result for Subspace-Momentum under standard assumptions. Empirical evaluations on LLM pre-training and fine-tuning demonstrate the effectiveness of our methods. For instance, combining Subset-Norm with Subspace-Momentum reaches Adam's validation perplexity for LLaMA 1B in approximately half the training tokens (6.8B vs. 13.1B) while reducing the memory footprint of Adam's optimizer states by more than 80%, with minimal additional hyperparameter tuning.
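The Subspace-Momentum idea can likewise be sketched in pure Python. This is a hypothetical illustration, not the authors' implementation: it assumes a fixed orthonormal basis $P \in \mathbb{R}^{d \times k}$ (built here by Gram-Schmidt from random vectors) and heavy-ball momentum $m \leftarrow \beta m + P^\top g$; how the paper selects or refreshes the subspace, and its exact momentum form, are not specified here. The update direction is $P m + (I - P P^\top) g$, so only the $k$-dimensional $m$ is stored as state. The names `orthonormalize` and `SubspaceMomentumSGD` are hypothetical.

```python
import math
import random


def orthonormalize(vectors):
    """Gram-Schmidt: turn k linearly independent vectors into an orthonormal basis."""
    basis = []
    for v in vectors:
        w = list(v)
        for q in basis:
            c = sum(a * b for a, b in zip(w, q))
            w = [a - c * b for a, b in zip(w, q)]
        norm = math.sqrt(sum(a * a for a in w))
        basis.append([a / norm for a in w])
    return basis


class SubspaceMomentumSGD:
    """Hypothetical sketch: heavy-ball momentum on span(basis), plain SGD on its orthogonal complement."""

    def __init__(self, basis, lr=0.1, beta=0.9):
        self.basis = basis           # k orthonormal vectors of length d (assumed fixed here)
        self.lr = lr
        self.beta = beta
        self.m = [0.0] * len(basis)  # momentum state has only k entries, not d

    def step(self, x, g):
        # Subspace coordinates of the gradient: c = P^T g
        c = [sum(qi * gi for qi, gi in zip(q, g)) for q in self.basis]
        # Orthogonal-complement component: g_perp = g - P c
        g_perp = list(g)
        for cj, q in zip(c, self.basis):
            g_perp = [a - cj * b for a, b in zip(g_perp, q)]
        # Momentum is accumulated only in the k-dimensional subspace
        self.m = [self.beta * mj + cj for mj, cj in zip(self.m, c)]
        # Direction = P m + g_perp (momentum in the subspace, SGD outside it)
        direction = list(g_perp)
        for mj, q in zip(self.m, self.basis):
            direction = [a + mj * b for a, b in zip(direction, q)]
        return [xi - self.lr * di for xi, di in zip(x, direction)]


if __name__ == "__main__":
    random.seed(0)
    d, k = 8, 2
    basis = orthonormalize([[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)])
    opt = SubspaceMomentumSGD(basis, lr=0.1, beta=0.9)
    x = [1.0] * d
    for _ in range(200):
        x = opt.step(x, list(x))  # gradient of 0.5 * ||x||^2 is x itself
    print("loss after 200 steps:", 0.5 * sum(v * v for v in x))
```

With `beta=0` the direction collapses to $P P^\top g + (I - P P^\top) g = g$, i.e. plain SGD, which is a quick sanity check on the projection code.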
