A geometric framework for momentum-based optimizers for low-rank training
Steffen Schotthöfer, Timon Klein, Jonas Kusch
TL;DR
The paper tackles the mismatch between momentum-based optimization and low-rank neural network parameterizations by developing a geometry-aware, dynamical low-rank training framework. It derives two practical optimizers, a low-rank heavy-ball method and a low-rank Adam, which preserve the low-rank structure through tangent-space projections and basis augmentation followed by rank truncation, and it backs them with a convergence analysis and robust error bounds. Empirically, the methods converge faster and reach stronger validation metrics under tight parameter budgets across vision and language tasks, covering transfer learning, fine-tuning, and pre-training (e.g., VGG/ViT, DeBERTa, Llama2, GPT-2). The work offers a scalable, theoretically grounded alternative to full-rank training of large models, enabling efficient training on resource-constrained devices while maintaining competitive or superior performance.
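The TL;DR names two geometric ingredients: tangent-space projections, and basis augmentation followed by rank truncation. As a rough illustration of the second one, the NumPy sketch below performs a single plain gradient step (no momentum) on a weight stored as W = U S V^T: it enlarges both bases with gradient directions, updates the small core in the enlarged basis, and truncates back to rank r. This is a simplified, hypothetical variant written for this summary, not the authors' optimizer; the function names `augmented_step` and `truncate`, the fixed target rank `r`, and the use of the full gradient with respect to W are all assumptions.

```python
import numpy as np


def truncate(U, S, V, r):
    """Return the rank-r truncation of U @ S @ V.T via an SVD of the small core S."""
    P, sigma, Qt = np.linalg.svd(S)
    return U @ P[:, :r], np.diag(sigma[:r]), V @ Qt[:r, :].T


def augmented_step(U, S, V, grad, lr, r):
    """One gradient step in an augmented basis, then truncation back to rank r.

    U (m x r) and V (n x r) have orthonormal columns, S is the r x r core, and
    grad (m x n) is the gradient of the loss with respect to W = U @ S @ V.T.
    Hypothetical illustration, not the paper's method.
    """
    # Augment each basis with the directions the gradient acts on.
    U_hat, _ = np.linalg.qr(np.hstack([U, grad @ V]))    # m x 2r, spans U and grad @ V
    V_hat, _ = np.linalg.qr(np.hstack([V, grad.T @ U]))  # n x 2r, spans V and grad.T @ U
    # Core of W - lr * grad in the augmented basis, without forming W itself.
    S_hat = (U_hat.T @ U) @ S @ (V.T @ V_hat) - lr * (U_hat.T @ grad @ V_hat)
    # Discard the smallest singular values to return to rank r.
    return truncate(U_hat, S_hat, V_hat, r)
```

With a zero gradient the sketch returns the same rank-r weight, because the augmented basis still contains the old factors; the only approximation it introduces is the final truncation. A rank-adaptive variant could pick the new rank from a tolerance on the discarded singular values instead of fixing `r`.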
Abstract
Low-rank pre-training and fine-tuning have recently emerged as promising techniques for reducing the computational and storage costs of large neural networks. Training low-rank parameterizations typically relies on conventional optimizers such as heavy-ball momentum methods or Adam. In this work, we identify and analyze potential difficulties that these training methods encounter when used to train low-rank parameterizations of weights. In particular, we show that classical momentum methods can struggle to converge to a local optimum due to the geometry of the underlying optimization landscape. To address this, we introduce novel training strategies derived from dynamical low-rank approximation, which explicitly account for the underlying geometric structure. Our approach leverages and combines tools from dynamical low-rank approximation and momentum-based optimization to design optimizers that respect the intrinsic geometry of the parameter space. We validate our methods through numerical experiments, demonstrating faster convergence and stronger validation metrics at a given parameter budget.
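One way to make "respecting the intrinsic geometry" concrete is a textbook Riemannian heavy-ball step on the manifold of rank-r matrices. In the hedged sketch below, the momentum buffer is kept in the tangent space at the current weight W = U S V^T: at every step it is re-projected onto the new tangent space (projection used as vector transport), and the update is mapped back to rank r by a truncated SVD. This generic construction is chosen for illustration only and is not the low-rank heavy-ball method derived in the paper; the helper names and the dense m x n storage of the momentum are assumptions made for readability.

```python
import numpy as np


def tangent_project(U, V, Z):
    """Orthogonally project Z (m x n) onto the tangent space of the rank-r
    manifold at W = U S V^T, where U and V have orthonormal columns:
    P(Z) = U U^T Z + Z V V^T - U U^T Z V V^T."""
    UtZ = U.T @ Z
    ZV = Z @ V
    return U @ UtZ + ZV @ V.T - U @ (UtZ @ V) @ V.T


def retract(Y, r):
    """Map Y back to the rank-r manifold with a truncated SVD."""
    P, sigma, Qt = np.linalg.svd(Y, full_matrices=False)
    return P[:, :r], np.diag(sigma[:r]), Qt[:r, :].T


def projected_heavy_ball_step(U, S, V, M, grad, lr, beta, r):
    """One heavy-ball step whose momentum M lies in the tangent space at W.

    Hypothetical illustration: dense matrices are used for clarity only.
    """
    # Project the old momentum (vector transport) and the Euclidean gradient
    # onto the tangent space at the current point W = U S V^T.
    M = beta * tangent_project(U, V, M) + tangent_project(U, V, grad)
    # Step in the ambient space, then retract to rank r.
    U, S, V = retract(U @ S @ V.T - lr * M, r)
    return U, S, V, M
```

Storing `M` densely defeats the memory savings of the low-rank format; an efficient version would keep tangent vectors in factored form (a small core plus two thin factors). The dense form here only makes the projection formula easy to check.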
