Optimal low-rank stochastic gradient estimation for LLM training

Zehao Li, Tao Ren, Zishi Zhang, Xi Chen, Yijie Peng

Abstract

Large language model (LLM) training is often bottlenecked by memory constraints and stochastic gradient noise in extremely high-dimensional parameter spaces. Motivated by empirical evidence that many LLM gradient matrices are effectively low-rank during training, we present an unbiased, memory-efficient, low-rank matrix estimator with the lowest variance that is applicable across common stochastic gradient estimation paradigms. The core idea is to project a high-dimensional stochastic gradient estimator onto a random low-dimensional subspace and lift it back, reducing memory while keeping the estimator unbiased and controlling mean-squared error via an optimally designed projection distribution, including Haar–Stiefel projections. The projection distribution is derived by solving a constrained functional optimization problem, yielding an optimal random projector that guides algorithm design. Empirically, the resulting low-rank gradient estimators deliver both practical memory savings and improved training behavior. In RoBERTa-large fine-tuning, our method attains the lowest peak GPU memory among compared methods (e.g., 3.83 GB versus 16.7 GB for full BP) while remaining competitive in accuracy; in autoregressive LLM pretraining (LLaMA-20M/60M/100M), our method outperforms the traditional methods, supporting the benefit of the proposed optimal projection strategy.
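
The project-and-lift idea in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names `haar_stiefel` and `project_and_lift`, the toy dimensions, and the scaling $\sqrt{n/r}$ (chosen here so that $\mathbb{E}[VV^\top]=I_n$) are all assumptions made for illustration. The sketch compresses an $m\times n$ gradient to $m\times r$, lifts it back, and checks numerically that the average over many random projections approaches the original gradient.

```python
import numpy as np

def haar_stiefel(n, r, rng):
    """Sample an n x r matrix with orthonormal columns (Haar on the Stiefel manifold)."""
    a = rng.standard_normal((n, r))
    q, upper = np.linalg.qr(a)            # reduced QR: q has shape (n, r)
    return q * np.sign(np.diag(upper))    # remove the QR sign ambiguity column by column

def project_and_lift(grad, v):
    """Compress grad (m x n) to an m x r matrix, then lift it back to m x n."""
    compact = grad @ v                    # only this m x r matrix has to be stored
    return compact @ v.T

rng = np.random.default_rng(0)
m, n, r = 8, 16, 4
grad = rng.standard_normal((m, n))        # stand-in for one stochastic gradient

# Assumed normalization: V = sqrt(n / r) * Q, so that E[V V^T] = I_n.
num_draws = 5000
avg = np.zeros_like(grad)
for _ in range(num_draws):
    v = np.sqrt(n / r) * haar_stiefel(n, r, rng)
    avg += project_and_lift(grad, v)
avg /= num_draws

# Relative deviation of the Monte Carlo mean from the true gradient.
rel_err = np.linalg.norm(avg - grad) / np.linalg.norm(grad)
print(f"relative error of the averaged estimator: {rel_err:.4f}")
```

In this toy setting the compressed matrix has $m r = 32$ entries against $m n = 128$ for the full gradient, which is where the memory saving would come from; the relative error shrinks as more projections are averaged.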

Paper Structure (19 sections, 7 theorems, 113 equations, 9 figures, 3 tables, 4 algorithms)

Key Result

Theorem 1

Fix a parameter block $\Theta\in\mathbb R^{m\times n}$ and let $V\in\mathbb R^{n\times r}$ be a random projection matrix independent of the data randomness $\xi$. The random subspace projection $V$ is sampled from a distribution in the class $\mathcal{D}$ (Equation dist). Assume the standard requirements for … In particular, if $c=1$ then $\hat{g}_{\mathrm{LowRank\text{-}IPA}}$ and …
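
The role of the constant $c$ can be seen from a one-line calculation. The following is a sketch under assumptions not confirmed by the truncated statement above: that membership in $\mathcal{D}$ means $\mathbb{E}[VV^\top]=c\,I_n$, that the low-rank estimator has the generic project-and-lift form $\hat g\,VV^\top$, and that $\hat g$ is an unbiased full-dimensional estimator of a gradient written here as $\nabla_\Theta F$ (the symbols $\hat g$ and $F$ are introduced only for this illustration). Since $V$ is independent of $\xi$,

$$
\mathbb{E}_{\xi,V}\!\left[\hat g(\xi)\,VV^\top\right]
= \mathbb{E}_{\xi}\!\left[\hat g(\xi)\right]\,\mathbb{E}_{V}\!\left[VV^\top\right]
= \nabla_\Theta F \cdot c\,I_n
= c\,\nabla_\Theta F .
$$

Under these assumptions the lifted estimator is unbiased exactly when $c=1$, and for $c\neq 1$ it is a scaled version of the gradient.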

Figures (9)

  • Figure 1: Illustration of the proposed low-rank gradient estimator and the lazy-update gradient descent framework. The gradient estimator with respect to $\Theta$ is computed via a rank-$r$ reparameterization using an auxiliary variable $B$ and a randomly sampled projection matrix $V$ (yellow trapezoids), and is then lifted back to the original parameter space by multiplying by $V^\top$ (blue block). In the lazy-update scheme, the same projection direction $V_0$ is reused for $K$ inner steps before switching to a new projection direction $V_1$ (green trapezoids).
  • Figure 2: MSE versus samples for the independent Likelihood Ratio (LR) estimator
  • Figure 3: MSE versus samples for the independent Infinitesimal Perturbation Analysis (IPA) estimator
  • Figure 4: MSE versus samples for the dependent Likelihood Ratio (LR) estimator
  • Figure 5: MSE versus samples for the dependent Infinitesimal Perturbation Analysis (IPA) estimator
  • ...and 4 more figures
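
The lazy-update scheme described in the Figure 1 caption can be sketched on a toy problem. This is only an illustration under stated assumptions, not the paper's algorithm: the quadratic loss, the function names (`sample_projection`, `lazy_update_descent`), the step size, and the $\sqrt{n/r}$ scaling of $V$ are all hypothetical choices. The parameter block is reparameterized as $\Theta + BV^\top$, the auxiliary variable $B$ is updated for $K$ inner steps with $V$ held fixed, and the result is then merged into $\Theta$ before a new $V$ is drawn.

```python
import numpy as np

def sample_projection(n, r, rng):
    """Scaled Haar-Stiefel projection (assumed scaling: V^T V = (n / r) * I_r)."""
    a = rng.standard_normal((n, r))
    q, upper = np.linalg.qr(a)
    return np.sqrt(n / r) * q * np.sign(np.diag(upper))

def loss(theta, target):
    """Toy quadratic loss 0.5 * ||theta - target||_F^2."""
    return 0.5 * np.sum((theta - target) ** 2)

def lazy_update_descent(theta, target, r, outer_steps, inner_steps, lr, rng):
    m, n = theta.shape
    history = [loss(theta, target)]
    for _ in range(outer_steps):
        v = sample_projection(n, r, rng)   # reused for all K inner steps
        b = np.zeros((m, r))               # auxiliary variable B, reset per projection
        for _ in range(inner_steps):
            # Toy shortcut: the full residual is formed here for clarity; a
            # memory-efficient version would obtain grad_b without it.
            residual = theta + b @ v.T - target
            grad_b = residual @ v          # gradient with respect to B, shape (m, r)
            b -= lr * grad_b
        theta = theta + b @ v.T            # lift back and merge into theta
        history.append(loss(theta, target))
    return theta, history

rng = np.random.default_rng(0)
m, n, r = 8, 16, 4
target = rng.standard_normal((m, n))
theta0 = np.zeros((m, n))
theta, history = lazy_update_descent(
    theta0, target, r, outer_steps=50, inner_steps=5, lr=0.1, rng=rng
)
print(f"initial loss {history[0]:.3f}, final loss {history[-1]:.3e}")
```

Only $B$ (shape $m\times r$) is updated inside the inner loop, so optimizer state in a realistic setting would scale with $r$ rather than $n$. The step size 0.1 is tied to the assumed scaling: with $V^\top V=(n/r)I_r$ the inner problem has curvature $n/r=4$, so any step below $0.5$ keeps the toy loss decreasing.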

Theorems & Definitions (14)

  • Example 1: IPA gradient estimator for a one-layer ReLU network
  • Example 2: Special case of LR Gradient Estimator: ZO optimization
  • Definition 1
  • Definition 2: Low-rank Gradient Estimators
  • Definition 3
  • Theorem 1: Unbiasedness of the low-rank estimators
  • Example 3: Low-rank gradient estimators of feed-forward neural networks
  • Proposition 1
  • Theorem 2: Optimal instance-independent low-rank random projector
  • Proposition 2: Constructions
  • ...and 4 more