SubTrack++ : Gradient Subspace Tracking for Scalable LLM Training

Sahar Rajabi, Nayeema Nonta, Sirisha Rambhatla

TL;DR

Large language model training is hindered by memory and time demands. SubTrack++ combines Grassmannian gradient subspace tracking with a projection-aware optimizer and recovery scaling to enable full-parameter training with a memory footprint on par with GaLore and Fira, while delivering substantial speedups. Theoretical convergence guarantees are accompanied by empirical evidence of state-of-the-art convergence and significant wall-time reductions across multiple model scales. This approach offers a practical, scalable path toward democratizing access to large models without sacrificing performance.

Abstract

Training large language models (LLMs) is highly resource-intensive due to their massive number of parameters and the overhead of optimizer states. While recent work has aimed to reduce memory consumption, such efforts often entail trade-offs among memory efficiency, training time, and model performance. Yet, true democratization of LLMs requires simultaneous progress across all three dimensions. To this end, we propose SubTrack++, which leverages Grassmannian gradient subspace tracking combined with projection-aware optimizers, enabling Adam's internal statistics to adapt to subspace changes. Additionally, recovery scaling, a technique that restores information lost through low-rank projections, further enhances model performance. Our method demonstrates SOTA convergence by exploiting Grassmannian geometry, reducing pre-training wall-time by up to 65% and fine-tuning time by 36% compared to existing SOTA methods, while maintaining the same memory footprint.
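The memory saving described above comes from keeping optimizer statistics in a low-rank space. The sketch below is a generic, GaLore-style illustration, assuming an orthonormal basis `S` of shape (m, r) and the projection `S.T @ G`; it is not SubTrack++'s projection-aware optimizer (it leaves Adam's moments untouched when `S` changes) and it omits recovery scaling, whose exact form is not given here. The function names are hypothetical.

```python
import numpy as np


def init_low_rank_state(r, n):
    """Adam moments stored in the projected (r, n) space instead of (m, n)."""
    return {"m": np.zeros((r, n)), "v": np.zeros((r, n)), "t": 0}


def low_rank_adam_step(W, G, S, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step computed in the rank-r subspace spanned by S (m, r).

    Hypothetical generic sketch: it does NOT adapt the moments when S
    changes and does NOT apply recovery scaling.
    """
    R = S.T @ G                                   # (r, n) projected gradient
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * R
    state["v"] = beta2 * state["v"] + (1 - beta2) * R ** 2
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** state["t"])  # bias-corrected second moment
    update = m_hat / (np.sqrt(v_hat) + eps)       # (r, n) normalized step
    return W - lr * (S @ update)                  # lift the step back to (m, n)
```

For a 64 x 32 weight with rank r = 4, the two moments occupy 2 · 4 · 32 = 256 entries instead of 2 · 64 · 32 = 4096 for full Adam (plus 64 · 4 = 256 entries for `S` itself), and every weight update lies in the span of `S`.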

Paper Structure

This paper contains 13 sections, 4 theorems, 54 equations, 6 figures, 12 tables, and 1 algorithm.

Key Result

Theorem 3.2

Suppose the gradient has the following form, with functions $A_i$, $B_i$, and $C_i$ being L-continuous as per Definition 3.1 with constants $L_A$, $L_B$, and $L_C$ w.r.t. the weight matrix $W_t$, and $\|W_t\|_F \leq M$, where $W_t$ denotes the weight matrix at step $t$ and $M$ is a scalar. Now, define $\widehat{B}_{i,t} = (S_{i, t}^l)^\top B_i(W_t) S_{i, t}^l$ and $\widehat{C}_{i,t} = (S_{i, t}^r)^\top \ldots$

Figures (6)

  • Figure 1: We compare baselines on pre-training a 1B-parameter model. (a) SubTrack++ achieves the lowest evaluation loss across all methods. (b) Its peak memory usage is significantly lower than APOLLO and LDAdam, and on par with GaLore and Fira. (c) In terms of wall-time, SubTrack++ incurs minimal overhead relative to APOLLO and is markedly faster than GaLore, Fira, and LDAdam. Overall, SubTrack++ outperforms all baselines in evaluation loss while matching or exceeding them in memory and runtime efficiency.
  • Figure 2: Visualization of Grassmannian subspace tracking: between subspace updates, gradients are projected onto a fixed subspace. The tangent vector $\nabla F$ is computed from the derivative of a loss function that measures the subspace estimation error. The subspace is then updated by moving along the geodesic determined by $\nabla F$, reducing the estimation error.
  • Figure 3: Ablation study comparing pure Grassmannian subspace tracking with incremental additions of the projection-aware optimizer and recovery scaling, leading to SubTrack++. While Grassmannian tracking alone nearly matches GaLore’s step-wise convergence (a), it significantly reduces wall-time (b).
  • Figure 4: Ablation results on (a) update frequency: decreasing the update interval (i.e., increasing the frequency) improves evaluation performance up to a point, but overly frequent updates hinder training convergence; (b) update rank: increasing the rank of updates degrades model performance and, beyond a certain threshold, can prevent convergence. These results emphasize the importance of controlled subspace adjustments.
  • Figure 5: Comparison of Grassmannian subspace tracking (Ours) (a, c) and GaLore's SVD (b, d) on the Ackley Function over 100 optimization steps, with a subspace update interval of 10. SF stands for scale factor; with a scale factor of 1, GaLore fails to reach the global minimum due to abrupt jumps. At a scale factor of 3, while the minimum is reached, the jump length increases. This demonstrates SVD's sensitivity to noise and abrupt changes, highlighting the robustness of our subspace tracking method with its controlled subspace updates.
  • ...and 1 more figure
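Figure 2 describes one tracking step only at a high level. The sketch below assumes the estimation-error loss is the squared residual F(S) = ||G - S S^T G||_F^2 of projecting a gradient `G` onto an orthonormal basis `S`, and moves `S` along the standard closed-form Grassmann geodesic (Edelman, Arias, and Smith); the loss, step size, and any normalization used by SubTrack++ itself may differ. The names `track_subspace`, `estimation_error`, and `eta` are hypothetical.

```python
import numpy as np


def grassmann_exp(S, H, t):
    """Point at time t on the Grassmann geodesic from span(S) in direction H.

    With the thin SVD H = U diag(sigma) V^T:
        S(t) = S V cos(sigma t) V^T + U sin(sigma t) V^T.
    Requires S^T S = I and S^T H = 0 (H is a horizontal tangent vector).
    """
    U, sigma, Vt = np.linalg.svd(H, full_matrices=False)
    cos_t = np.diag(np.cos(sigma * t))
    sin_t = np.diag(np.sin(sigma * t))
    return S @ Vt.T @ cos_t @ Vt + U @ sin_t @ Vt


def estimation_error(S, G):
    """Assumed loss F(S) = ||G - S S^T G||_F^2: energy of G outside span(S)."""
    return float(np.linalg.norm(G - S @ (S.T @ G)) ** 2)


def track_subspace(S, G, eta):
    """Hypothetical tracking step: descend F along a Grassmann geodesic."""
    A = S.T @ G                      # (r, n) coordinates of G in the subspace
    residual = G - S @ A             # (m, n) component of G orthogonal to span(S)
    grad_F = -2.0 * residual @ A.T   # Riemannian gradient of F; S^T grad_F = 0
    S_new = grassmann_exp(S, -grad_F, eta)
    Q, _ = np.linalg.qr(S_new)       # re-orthonormalize to remove round-off drift
    return Q
```

Because the residual is orthogonal to `span(S)`, the gradient is already a valid tangent direction, so no extra projection is needed; for a small enough `eta` the step lowers `estimation_error` while rotating the basis gradually rather than replacing it wholesale, as a fresh SVD would.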

Theorems & Definitions (9)

  • Definition 3.1: L-continuity
  • Theorem 3.2: Convergence of Grassmannian Subspace Tracking
  • Definition 3.3: Exponential Map
  • Definition 3.4: Stiefel Manifold
  • Definition 3.5: Grassmann Manifold
  • Theorem 3.6: Grassmann Exponential
  • Theorem A.1: Convergence of Grassmannian Subspace Tracking
  • Theorem B.1: Grassmann Exponential
  • Definition B.1: Orthogonal Group
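For reference, the standard forms of the objects named in Definitions 3.4, 3.5, and B.1 and in Theorem 3.6 are collected below, following Edelman, Arias, and Smith ("The Geometry of Algorithms with Orthogonality Constraints"). The paper's own statements are not reproduced on this page, so its notation and conventions may differ; $m$, $r$, $S$, $H$, $U$, $\Sigma$, $V$ are illustrative symbols.

```latex
\begin{aligned}
\mathrm{St}(m, r) &= \{\, S \in \mathbb{R}^{m \times r} : S^\top S = I_r \,\}, \\
\mathrm{O}(r) &= \{\, Q \in \mathbb{R}^{r \times r} : Q^\top Q = I_r \,\}, \\
\mathrm{Gr}(m, r) &= \mathrm{St}(m, r) / \mathrm{O}(r), \qquad [S] = \{\, S Q : Q \in \mathrm{O}(r) \,\}, \\
T_{[S]}\,\mathrm{Gr}(m, r) &\cong \{\, H \in \mathbb{R}^{m \times r} : S^\top H = 0 \,\}, \\
S(t) &= S V \cos(\Sigma t)\, V^\top + U \sin(\Sigma t)\, V^\top,
\qquad H = U \Sigma V^\top \ \text{(thin SVD)}.
\end{aligned}
```

Here $\cos$ and $\sin$ act on the diagonal entries of $\Sigma$; $S(0) = S$, $\dot S(0) = H$, and $S(t)$ stays orthonormal for all $t$ because $S^\top H = 0$.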