A geometric framework for momentum-based optimizers for low-rank training

Steffen Schotthöfer, Timon Klein, Jonas Kusch

TL;DR

The paper tackles the mismatch between momentum-based optimization and low-rank neural network parameterizations by developing a geometry-aware, dynamical low-rank training framework. It derives two practical optimizers, low-rank heavy ball and low-rank Adam, that preserve low-rank structure via tangent-space projections and basis augmentation with rank truncation, backed by convergence analysis and robust error bounds. Empirically, the methods deliver faster convergence and stronger validation metrics under tight parameter budgets across vision and language tasks, including transfer learning, fine-tuning, and pretraining (e.g., VGG/ViT, DeBERTa, Llama2, GPT-2). This work offers a scalable, theoretically grounded alternative to full-rank training for large models, enabling efficient training on resource-constrained devices while maintaining competitive or superior performance.
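As a concrete illustration of the tangent-space projections mentioned above, here is a minimal NumPy sketch of the textbook orthogonal projection onto the tangent space of the manifold of rank-$r$ matrices at $W = U S V^\top$, assuming $U$ and $V$ have orthonormal columns: $P(W)Z = UU^\top Z + ZVV^\top - UU^\top Z VV^\top$. The function name project_tangent is hypothetical; this is a generic sketch of the standard formula, not the paper's implementation.

```python
import numpy as np


def project_tangent(U: np.ndarray, V: np.ndarray, Z: np.ndarray) -> np.ndarray:
    """Project Z orthogonally onto the tangent space of the rank-r manifold at W = U S V^T.

    U (m x r) and V (n x r) must have orthonormal columns; S does not enter the formula.
    """
    UtZ = U.T @ Z                      # r x n
    ZV = Z @ V                         # m x r
    # P Z = U U^T Z + Z V V^T - U U^T Z V V^T
    return U @ UtZ + ZV @ V.T - U @ (UtZ @ V) @ V.T


rng = np.random.default_rng(0)
m, n, r = 8, 6, 2
U, _ = np.linalg.qr(rng.standard_normal((m, r)))
V, _ = np.linalg.qr(rng.standard_normal((n, r)))
W = U @ np.diag([3.0, 1.0]) @ V.T      # a rank-2 point on the manifold
G = rng.standard_normal((m, n))        # random stand-in for a full gradient
PG = project_tangent(U, V, G)          # projected gradient, rank at most 2r
```

Two properties one would expect from an orthogonal projection hold here: applying it twice changes nothing, and $W$ lies in its own tangent space, so $P(W)W = W$.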

Abstract

Low-rank pre-training and fine-tuning have recently emerged as promising techniques for reducing the computational and storage costs of large neural networks. Training low-rank parameterizations typically relies on conventional optimizers such as heavy-ball momentum methods or Adam. In this work, we identify and analyze potential difficulties that these training methods encounter when used to train low-rank parameterizations of weights. In particular, we show that classical momentum methods can struggle to converge to a local optimum due to the geometry of the underlying optimization landscape. To address this, we introduce novel training strategies derived from dynamical low-rank approximation, which explicitly account for the underlying geometric structure. Our approach leverages and combines tools from dynamical low-rank approximation and momentum-based optimization to design optimizers that respect the intrinsic geometry of the parameter space. We validate our methods through numerical experiments, demonstrating faster convergence and stronger validation metrics at given parameter budgets.
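The basis augmentation and rank truncation referred to in the TL;DR can be sketched as follows. This is a generic NumPy illustration under stated assumptions, not the authors' algorithm: augment_basis and truncate are hypothetical names, the new directions are random stand-ins for gradient-informed directions, and the relative Frobenius-norm tolerance rel_tol is an assumed truncation criterion. Truncation only needs an SVD of the small $2r \times 2r$ coefficient matrix, never of the full weight.

```python
import numpy as np


def augment_basis(U: np.ndarray, new_dirs: np.ndarray) -> np.ndarray:
    """Orthonormal basis spanning the old basis U together with new directions."""
    Q, _ = np.linalg.qr(np.hstack([U, new_dirs]))
    return Q


def truncate(U_hat, S_hat, V_hat, rel_tol, max_rank=None):
    """Truncate U_hat @ S_hat @ V_hat.T using an SVD of the small coefficient matrix only.

    Keeps the smallest rank whose discarded singular values have Frobenius
    norm at most rel_tol * ||S_hat||_F (optionally capped at max_rank).
    """
    P, sigma, Qt = np.linalg.svd(S_hat, full_matrices=False)
    # tail[k] = Frobenius norm of the singular values dropped when keeping rank k.
    tail = np.append(np.sqrt(np.cumsum(sigma[::-1] ** 2))[::-1], 0.0)
    r = max(int(np.argmax(tail <= rel_tol * np.linalg.norm(sigma))), 1)
    if max_rank is not None:
        r = min(r, max_rank)
    return U_hat @ P[:, :r], np.diag(sigma[:r]), V_hat @ Qt[:r].T


rng = np.random.default_rng(1)
m, n, r = 10, 7, 2
U, _ = np.linalg.qr(rng.standard_normal((m, r)))
V, _ = np.linalg.qr(rng.standard_normal((n, r)))
# Random stand-ins for new (e.g. gradient-informed) directions.
U_hat = augment_basis(U, rng.standard_normal((m, r)))
V_hat = augment_basis(V, rng.standard_normal((n, r)))
# Coefficients in the augmented 2r x 2r basis with two negligible singular values.
A, _ = np.linalg.qr(rng.standard_normal((2 * r, 2 * r)))
B, _ = np.linalg.qr(rng.standard_normal((2 * r, 2 * r)))
S_hat = A @ np.diag([5.0, 3.0, 1e-8, 1e-9]) @ B.T
U_new, S_new, V_new = truncate(U_hat, S_hat, V_hat, rel_tol=1e-6)
```

With rel_tol = 1e-6, the two negligible singular values are dropped and the rank falls from 2r = 4 back to 2.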

Paper Structure

This paper contains 35 sections, 4 theorems, 50 equations, 4 figures, 13 tables, and 6 algorithms.

Key Result

Theorem 1

Let $W(t)$ be the solution of the gradient flow in Eq. (eq:gradflowopt) and let $\mathcal{L}$ be bounded from below. Then $W(t)$ converges to a $W^{\star}$ that fulfills the low-rank optimality condition
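For orientation, the block below writes out a common form of a projected gradient flow and low-rank optimality condition from the dynamical low-rank literature. It is an assumed standard form, not a transcription of Eq. (eq:gradflowopt) or of the paper's proof; $P(W)$ denotes the orthogonal projection onto the tangent plane $\mathcal{T}_{\mathcal{M}_r}$ at $W = USV^\top$ with orthonormal $U$, $V$. The last two lines show why boundedness from below matters: because $P(W)$ is self-adjoint and idempotent, the loss decreases at exactly the squared norm of the projected gradient, so that quantity is integrable in time.

```latex
\begin{aligned}
\dot W(t) &= -P\big(W(t)\big)\,\nabla_W \mathcal{L}\big(W(t)\big), \qquad W(0) \in \mathcal{M}_r,\\
P(W)Z &= UU^\top Z + Z VV^\top - UU^\top Z VV^\top,\\
0 &= P(W^{\star})\,\nabla_W \mathcal{L}(W^{\star}),\\
\tfrac{\mathrm{d}}{\mathrm{d}t}\,\mathcal{L}\big(W(t)\big)
  &= \big\langle \nabla_W \mathcal{L}(W), \dot W \big\rangle_F
   = -\big\| P(W)\,\nabla_W \mathcal{L}(W) \big\|_F^2 \le 0,\\
\int_0^{\infty} \big\| P(W)\,\nabla_W \mathcal{L}(W) \big\|_F^2 \,\mathrm{d}t
  &\le \mathcal{L}\big(W(0)\big) - \inf \mathcal{L} < \infty .
\end{aligned}
```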

Figures (4)

  • Figure 1: Geometric interpretation of the low-rank heavy-ball algorithm. We compute the parametrization of the tangent plane $\mathcal{T}_{\mathcal{M}_r}$. Then, we compute the projected gradient $\nabla_{\bar{S}}\mathcal{L}$ to construct the low-rank momentum update. The momentum optimizer is then applied to the low-rank weight coefficient $\widehat{S}$. Lastly, we retract the updated coefficients back onto the manifold $\mathcal{M}_r$. The interpretation of the low-rank Adam algorithm is analogous. LoRA-like methods do not employ orthogonal projections onto $\mathcal{T}_{\mathcal{M}_r}$; instead, they map the full gradient $\nabla_W\mathcal{L}$ implicitly onto $\mathcal{M}_r$. This linear map (displayed as the wavy orange line) may map the gradient direction far away from the properly projected gradient flow, leading to suboptimal descent directions.
  • Figure 2: ViT-L/32 on ImageNet-1k, pretrained from scratch in low-rank and full-rank baseline formats for 4000 iterations. The training loss and accuracy of low-rank Adam are close to those of the full-rank baseline, whereas LoRA pretraining struggles to converge within the training time budget.
  • Figure 3: ViT-small on CIFAR-10, pretrained from scratch in low-rank and full-rank baseline formats for 450 epochs; median trajectory over 5 runs. Low-rank Adam and LoRA pretraining initially converge faster than the full-rank baseline. After the initial warm-up phase, low-rank Adam exhibits a steeper convergence slope than LoRA. Moreover, low-rank Adam achieves lower loss and higher validation accuracy than LoRA, even surpassing the baseline. A naive DLRT implementation with Adam leads to slower convergence and a drop of over 10% in validation accuracy.
  • Figure 4: GPT-2 reproduction on OpenWebText, pretrained from scratch with LoRA, the full-rank baseline, and low-rank Adam for 15000 iterations. Low-rank Adam significantly outperforms LoRA pretraining (best validation loss $3.4642$ vs. $4.8141$), while incurring only a moderate increase in validation loss relative to the full-rank baseline ($3.4642$ vs. $3.2313$).
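The four steps in the Figure 1 caption (parametrize the tangent plane, form the projected coefficient gradient, apply momentum to the coefficients, retract onto $\mathcal{M}_r$) can be sketched in NumPy on a toy quadratic loss. This is a hypothetical illustration, not the paper's algorithm: lowrank_heavy_ball_step and orth are illustrative names, the augmented bases spanning $[U, GV]$ and $[V, G^\top U]$ are an assumed choice, and carrying the momentum buffer between bases by projecting it onto each new basis is also an assumption. For readability the sketch forms the full gradient G, which a memory-efficient implementation would avoid.

```python
import numpy as np


def orth(X: np.ndarray) -> np.ndarray:
    """Orthonormal basis for the column span of X (reduced QR)."""
    Q, _ = np.linalg.qr(X)
    return Q


def lowrank_heavy_ball_step(U, S, V, M, grad_fn, lr, beta):
    """One hypothetical heavy-ball step on the coefficients of W = U S V^T (rank r kept fixed)."""
    r = S.shape[0]
    G = grad_fn(U @ S @ V.T)                  # full gradient at W (toy setting only)
    # 1. Augmented bases containing the tangent-plane directions at W.
    U_hat = orth(np.hstack([U, G @ V]))       # m x 2r
    V_hat = orth(np.hstack([V, G.T @ U]))     # n x 2r
    # Express W and the (assumed) momentum buffer in the augmented bases.
    left, right = U_hat.T @ U, V.T @ V_hat    # 2r x r, r x 2r
    S_bar = left @ S @ right
    # 2. Gradient w.r.t. the augmented coefficients; 3. heavy-ball momentum on them.
    M_bar = beta * (left @ M @ right) + U_hat.T @ G @ V_hat
    S_hat = S_bar - lr * M_bar
    # 4. Retraction onto M_r: keep the r dominant singular directions of the coefficients.
    P, sigma, Qt = np.linalg.svd(S_hat)
    P_r, Q_r = P[:, :r], Qt[:r].T
    return U_hat @ P_r, np.diag(sigma[:r]), V_hat @ Q_r, P_r.T @ M_bar @ Q_r


rng = np.random.default_rng(0)
m, n, r = 12, 10, 3
A = orth(rng.standard_normal((m, r))) @ np.diag([3.0, 2.0, 1.0]) @ orth(rng.standard_normal((n, r))).T


def loss(W):
    """Toy quadratic loss 0.5 * ||W - A||_F^2 with a rank-3 target A."""
    return 0.5 * np.sum((W - A) ** 2)


def grad_fn(W):
    """Gradient of the toy loss with respect to the full weight W."""
    return W - A


U, V = orth(rng.standard_normal((m, r))), orth(rng.standard_normal((n, r)))
S, M = 0.01 * np.eye(r), np.zeros((r, r))
initial_loss = loss(U @ S @ V.T)
for _ in range(100):
    U, S, V, M = lowrank_heavy_ball_step(U, S, V, M, grad_fn, lr=0.5, beta=0.5)
final_loss = loss(U @ S @ V.T)
```

On this toy problem the target A has exact rank 3, so the rank-3 iterates can reach it. Replacing the single momentum buffer with first- and second-moment estimates on the coefficients would give an Adam-style analogue, again only as an illustration.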

Theorems & Definitions (8)

  • Theorem 1: Convergence
  • proof
  • Theorem 2: Convergence of low-rank factors
  • proof
  • Theorem 3: Error-bound
  • proof
  • Theorem 4
  • proof