Grass: Compute Efficient Low-Memory LLM Training with Structured Sparse Gradients

Aashiq Muhamed, Oscar Li, David Woodruff, Mona Diab, Virginia Smith

TL;DR

Grass is a memory-efficient optimization framework for LLMs that uses structured sparse gradient projections, reducing both optimizer-state and gradient memory while still enabling full-parameter training. It leverages a sparse projection matrix $P$ with $P^\top = \rho B$ to compute projected gradients without ever materializing the full gradient, yielding substantial compute, memory, and communication savings. Empirical results across pretraining, finetuning, and instruction finetuning demonstrate performance competitive with full-rank training and prior projection-based methods, while delivering significant throughput and memory benefits, including BF16 pretraining of a 13B LLaMA model on a single 40GB GPU and up to a $2\times$ throughput improvement on multi-GPU systems. Grass achieves these gains through principled stochastic and deterministic constructions of the sparse $P$, efficient gradient projection, and careful integration with distributed training, offering a practical path to scaling LLM training on modest hardware budgets.
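
As a rough illustration of how such a projection avoids forming the full gradient, here is a minimal PyTorch sketch, assuming a linear layer $y = xW^\top$ with $W \in \mathbb{R}^{m \times n}$ and a projection $P^\top = \rho B$ whose $j$-th row of $B$ has its single nonzero at index $\sigma_j$; the function name and shapes are illustrative, not the released Grass implementation.

```python
import torch

# Minimal sketch (not the official Grass code): for y = x @ W.T with W of
# shape (m, n), the full weight gradient is G = dy.T @ x. With P^T = rho @ B,
# where each row j of B selects column sigma_j and rho holds per-row scales,
# the projected gradient is P^T G = rho * G[sigma, :] = rho * (dy[:, sigma].T @ x),
# so only r of the m rows of G are ever formed.

def projected_grad(dy, x, sigma, rho):
    """dy: (batch, m) gradient w.r.t. the layer output, x: (batch, n) layer input,
    sigma: (r,) selected row indices, rho: (r,) per-row scales."""
    dy_sel = dy[:, sigma]              # (batch, r): pick the selected output dims
    g_proj = dy_sel.t() @ x            # (r, n): equals B @ G without building G
    return rho.unsqueeze(1) * g_proj   # (r, n): equals rho @ B @ G

# Toy usage: m = 6 output dims, n = 4 input dims, r = 2 selected rows.
dy, x = torch.randn(8, 6), torch.randn(8, 4)
sigma, rho = torch.tensor([1, 4]), torch.tensor([1.0, 1.0])
g_proj = projected_grad(dy, x, sigma, rho)
```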

Abstract

Large language model (LLM) training and finetuning are often bottlenecked by limited GPU memory. While existing projection-based optimization methods address this by projecting gradients into a lower-dimensional subspace to reduce optimizer state memory, they typically rely on dense projection matrices, which can introduce computational and memory overheads. In this work, we propose Grass (GRAdient Structured Sparsification), a novel approach that leverages sparse projections to transform gradients into structured sparse updates. This design not only significantly reduces memory usage for optimizer states but also minimizes gradient memory footprint, computation, and communication costs, leading to substantial throughput improvements. Extensive experiments on pretraining and finetuning tasks demonstrate that Grass achieves performance competitive with full-rank training and existing projection-based methods. Notably, Grass enables half-precision pretraining of a 13B parameter LLaMA model on a single 40GB A100 GPU--a feat infeasible for previous methods--and yields up to a $2\times$ throughput improvement on an 8-GPU system. Code can be found at https://github.com/aashiqmuhamed/GRASS.
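
To make the memory claim concrete, the hedged sketch below does a back-of-the-envelope element count per weight matrix $W \in \mathbb{R}^{m \times n}$ at projection rank $r$; the assumption that dense-projection methods still materialize the full $m \times n$ gradient, and the example shapes, are illustrative rather than the paper's measured numbers.

```python
# Rough per-layer element counts (not the paper's measurements) comparing
# full-rank Adam, a dense-projection method, and a Grass-style sparse projection.

def element_counts(m, n, r):
    full = m * n + 2 * m * n                 # full gradient + Adam first/second moments
    dense_proj = m * n + 2 * r * n + m * r   # full gradient still formed, plus
                                             # low-rank states and a dense m x r P
    sparse_proj = r * n + 2 * r * n + 2 * r  # only r gradient rows, low-rank states,
                                             # plus indices sigma and scales rho
    return full, dense_proj, sparse_proj

# Example: a 4096 x 4096 weight matrix at rank r = 128.
full, dense, sparse = element_counts(4096, 4096, 128)
print(full, dense, sparse)   # 50331648  18350080  1573120
```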

Paper Structure

This paper contains 81 sections, 4 theorems, 23 equations, 16 figures, 13 tables, and 4 algorithms.

Key Result

Theorem D.1

Let $B \in \{0, 1\}^{r \times m}$ be the sparse binary matrix with the unique non-zero index of the $j$-th row being $\sigma_j \in [m]$. Let $\sigma_j \stackrel{\textit{i.i.d.}}{\sim} \textrm{Multinomial}(1, q)$ ($q \in \mathbb{R}^m$ with the probability of sampling integer $k \in [m]$ being $q_k$). If
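
As a small illustration of this stochastic construction, the sketch below samples the row indices $\sigma_j$ i.i.d. from $\textrm{Multinomial}(1, q)$ and assembles $B$; the uniform $q$ in the example and the helper name are assumptions for demonstration only.

```python
import torch

# Hedged sketch of the stochastic sparse-B construction in Theorem D.1:
# each row j of B in {0,1}^{r x m} has a single 1 at column sigma_j,
# with sigma_j drawn i.i.d. from Multinomial(1, q).

def sample_sparse_B(q, r):
    """q: (m,) sampling probabilities (summing to 1); returns (sigma, B)."""
    m = q.numel()
    sigma = torch.multinomial(q, num_samples=r, replacement=True)  # r i.i.d. draws
    B = torch.zeros(r, m)
    B[torch.arange(r), sigma] = 1.0    # one nonzero per row, at column sigma_j
    return sigma, B

# Example: uniform q over m = 8 columns, r = 3 rows.
sigma, B = sample_sparse_B(torch.full((8,), 0.125), r=3)
```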

Figures (16)

  • Figure 1: Pretraining 1B LLaMA on 8.8B tokens of C4 with Grass, Full-rank and GaLore. (Left) Train perplexity vs seen tokens. (Right) Train perplexity vs wall-clock time. Grass outperforms GaLore and shows a $<0.01$ perplexity gap with the Full-rank loss curve in wall-clock time.
  • Figure 2: Normalized pretraining throughput at $r=64$ for Grass, Full-rank, and GaLore relative to Full-rank. Grass throughput exceeds Full and GaLore throughput by $>25\%$.
  • Figure 3: Pretraining memory footprint for Grass, GaLore, and Full across model sizes for a regular (non-projection-update) step and $r=128$. Grass has a lower memory footprint across all model sizes, and the reduction is greater at larger model sizes.
  • Figure 4: Normalized LLaMA finetuning throughput of Grass, GaLore, and LoRA relative to LoRA. We use rank $r=64$. Grass is $>18\%$ faster than LoRA.
  • Figure 5: Communication Efficiency: Weak Scaling Throughput Comparison for 3B LLaMA pretraining using Grass, Full-rank, and GaLore. Grass achieves $2\times$ higher throughput than Full-rank and GaLore at 8 GPUs.
  • ...and 11 more figures

Theorems & Definitions (7)

  • Theorem D.1
  • proof
  • Theorem: Complete statement of Theorem \ref{theorem:one}
  • proof
  • Theorem F.1: Subspace Preservation
  • proof
  • Lemma F.2: Sampling in Orthogonal Spaces