Randomized Gradient Subspaces for Efficient Large Language Model Training

Sahar Rajabi, Nayeema Nonta, Samanvay Vajpayee, Sirisha Rambhatla

TL;DR

This work tackles the memory bottleneck in large language model training by analyzing gradient subspace dynamics and introducing two randomized, subspace-aware methods, GrassWalk and GrassJump. Through empirical analysis of LLaMA pretraining, it shows that a core subspace captures most gradient energy early but loses dominance over time and in deeper layers, and that the gradient space exhibits near-flat curvature. The authors propose a Grassmannian-based framework that takes random walks or random jumps between subspaces, coupled with optimizer adaptations (AO) and information-recovery scaling (RS), to preserve learning signals while achieving state-of-the-art memory efficiency. Empirical results on LLaMA-1B and LLaMA-7B demonstrate strong memory savings and faster convergence, suggesting that randomized strategies, when aligned with gradient dynamics, can be principled and effective for scalable LLM training.

Abstract

Training large language models (LLMs) is often bottlenecked by extreme memory demands, with optimizer states dominating the footprint. Recent work mitigates this cost by projecting gradients into low-dimensional subspaces using sophisticated update strategies. In this paper, we analyze the dynamics of the gradient space and its underlying subspaces. We find that while a small subspace captures most gradient energy, a significant portion still resides in the residual bulk; moreover, the influence of the core subspace diminishes over time and in deeper layers. We also observe that the gradient space exhibits near-flat curvature, calling for algorithms that explicitly account for this geometry. Motivated by these insights, we introduce a suite of randomized algorithms, GrassWalk and GrassJump, which exploit this subspace structure and achieve state-of-the-art memory savings while improving performance on LLaMA-1B and LLaMA-7B pretraining.
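
To make the notion of "gradient energy captured by a small subspace" concrete, the following is a minimal NumPy sketch, not the authors' code. It assumes energy means the squared Frobenius norm, that the core subspace is spanned by the top-r left singular vectors, and it uses a synthetic matrix in place of a real LLaMA gradient; the names `energy_fraction` and `project` are hypothetical.

```python
import numpy as np


def energy_fraction(grad: np.ndarray, rank: int) -> float:
    """Share of squared Frobenius norm held by the top-`rank` singular directions."""
    s = np.linalg.svd(grad, compute_uv=False)
    return float(np.sum(s[:rank] ** 2) / np.sum(s ** 2))


def project(grad: np.ndarray, rank: int):
    """Split a gradient into its rank-`rank` core projection and the residual bulk."""
    U, _, _ = np.linalg.svd(grad, full_matrices=False)
    S = U[:, :rank]            # (m, r) orthonormal basis of the core subspace
    low = S.T @ grad           # (r, n) projected gradient; optimizer states would live at this size
    residual = grad - S @ low  # part of the gradient the projection discards
    return S, low, residual


rng = np.random.default_rng(0)
m, n, r = 64, 48, 4
# Synthetic stand-in for a gradient: a strong rank-r component plus dense noise (assumption).
G = rng.standard_normal((m, r)) @ rng.standard_normal((r, n)) + 0.5 * rng.standard_normal((m, n))

frac = energy_fraction(G, r)
S, low, residual = project(G, r)
print(f"rank-{r} energy fraction: {frac:.3f}")
print(f"residual energy fraction: {np.linalg.norm(residual) ** 2 / np.linalg.norm(G) ** 2:.3f}")
```

In this sketch the two printed fractions sum to one, so a core fraction that declines during training means a growing share of the signal sits in the residual, which is the part a purely low-rank update never sees.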

Paper Structure

This paper contains 7 sections, 11 equations, 4 figures, 2 tables.

Figures (4)

  • Figure 1: Each decoder layer stack includes seven layer types. In the 1B model, the plots show the fraction of gradient-matrix energy explained by a low-rank approximation. Despite a high lower bound, this fraction declines over training, and deeper layers generally exhibit smaller fractions.
  • Figure 2: Evolution of the top 20 singular values of the subspace estimation error derivative across different projection layers in the LLaMA-1B architecture. Each plot shows the maximum $i$-th singular value within a given layer type (aggregated across all 24 decoder layers) as training progresses. While MLP down-projection layers (g) exhibit the largest singular values, their magnitude remains small and decays rapidly. Other projection layers (a-d) show values close to zero throughout training. The overall distribution of singular values becomes more uniform as training advances, suggesting that the gradient subspace evolves with almost flat curvature.
  • Figure 3: We ablate (i) the subspace update method: Grassmannian tracking, Grassmannian random walk, random projections, and SVD; (ii) the adaptive optimizer (AO); and (iii) recovery scaling (RS), reporting evaluation loss (lower is better). The “No Subspace Update” variant freezes the initial SVD subspace $S_0$; because the subspace is fixed, AO is inapplicable and only RS is active.
  • Figure 4: Comparison of different methods on LLaMA pretraining. (a) Wall-clock training curves for LLaMA-1B across all baselines. (b) Pretraining results for LLaMA-7B across selected methods, excluding weaker baselines due to their large performance gap.
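
The difference between a random walk and a random jump over subspaces, as ablated in Figure 3, can be illustrated with a small NumPy sketch. This is an assumption-laden illustration rather than the GrassWalk or GrassJump update rule: it represents a subspace by an orthonormal basis, takes a walk step by perturbing along a random tangent direction and re-orthonormalizing with QR (a retraction, whereas the paper may use a different map), and draws a jump as a fresh random basis. The function names and the step size are hypothetical.

```python
import numpy as np


def random_jump(m: int, r: int, rng: np.random.Generator) -> np.ndarray:
    """Draw a fresh random r-dimensional subspace of R^m as an orthonormal basis."""
    Q, _ = np.linalg.qr(rng.standard_normal((m, r)))
    return Q


def random_walk_step(S: np.ndarray, step: float, rng: np.random.Generator) -> np.ndarray:
    """Move a short distance from the subspace spanned by S in a random direction."""
    Z = rng.standard_normal(S.shape)
    T = Z - S @ (S.T @ Z)              # remove the component inside span(S): a tangent direction
    Q, _ = np.linalg.qr(S + step * T)  # QR retraction back to an orthonormal basis
    return Q


def subspace_distance(A: np.ndarray, B: np.ndarray) -> float:
    """Projection-metric distance between span(A) and span(B)."""
    return float(np.linalg.norm(A @ A.T - B @ B.T) / np.sqrt(2))


rng = np.random.default_rng(0)
m, r = 64, 4
S0 = random_jump(m, r, rng)                        # starting subspace
S_walk = random_walk_step(S0, step=0.01, rng=rng)  # small local move
S_jump = random_jump(m, r, rng)                    # unrelated new subspace

d_walk = subspace_distance(S0, S_walk)
d_jump = subspace_distance(S0, S_jump)
print(f"distance after walk step: {d_walk:.3f}")
print(f"distance after jump:      {d_jump:.3f}")
```

With a small step the walk stays close to the previous subspace while the jump lands far from it, which is the trade-off between continuity of optimizer statistics and exploration that the AO and RS components are described as addressing.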