PLUMAGE: Probabilistic Low rank Unbiased Min Variance Gradient Estimator for Efficient Large Model Training
Matan Haroush, Daniel Soudry
TL;DR
PLUMAGE addresses the memory and networking bottlenecks of training billion-parameter LLMs with a probabilistic, unbiased, minimum-variance low-rank gradient estimator of fixed rank. It uses a one-sided projection and a wheel-of-fortune sampling scheme to obtain exactly $k$-sparse gradient estimates, and it aligns the optimizer's first and second moments across projection updates to stabilize training. Empirically, PLUMAGE narrows the gap to full-rank optimization (on average, by 33% on pre-training evaluation loss and by 28% on GLUE training loss) and outperforms GaLoRE in pre-training and fine-tuning, with no hyperparameters beyond the rank and the update interval. Its memory and compute footprint is similar to GaLoRE's, which makes large-model training more accessible and provides a foundation for future per-layer rank adaptation and adaptive projection strategies.
Abstract
Accelerator memory and networking constraints have emerged as dominant bottlenecks when training large language models (LLMs) with billions of parameters. Existing low-rank gradient estimators such as GaLoRE and FLORA compress gradients and optimizer tensors by projecting weight gradients onto a rank-$r$ subspace, enabling LLM training on consumer hardware. Yet these methods are either biased or subject to high estimator variance. Moreover, the optimizer state, which holds first- and second-moment estimates expressed in the previous subspace, becomes misaligned whenever the projection is updated, leading to instabilities during training. We propose PLUMAGE: Probabilistic Low-rank Unbiased Minimum-vAriance Gradient Estimator. PLUMAGE is a drop-in replacement for existing low-rank gradient estimators. It introduces no new hyperparameters beyond the chosen rank $r$ and the projection update interval. In addition, we resolve the optimizer state misalignment to prevent spurious weight updates and enhance training stability. We empirically demonstrate that PLUMAGE narrows the gap to full-rank optimization by 33% on average across models in pre-training evaluation loss, and by 28% in average training loss across the GLUE benchmark, within a computational and memory footprint similar to GaLoRE's.
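To make the idea of a fixed-rank, unbiased estimator concrete, the following is a minimal sketch and not the authors' implementation. It assumes the gradient has singular values $\sigma_i$ with left singular vectors $u_i$, that component $i$ is kept with probability $p_i$, and that each kept component is rescaled by $1/p_i$, which makes the one-sided estimate $\sum_{i \in I} (1/p_i)\, u_i u_i^\top G$ unbiased. Under this assumption the variance $\sum_i \sigma_i^2 (1/p_i - 1)$ is minimized, subject to $\sum_i p_i = k$ and $p_i \le 1$, by $p_i = \min(1, c\,\sigma_i)$. The sampler is classical systematic sampling with $k$ equally spaced pointers, which we assume resembles the wheel-of-fortune scheme. The function names `inclusion_probabilities`, `wheel_sample`, and `sample_components` are hypothetical.

```python
import random


def inclusion_probabilities(sigma, k):
    """Return p with p_i = min(1, c * sigma_i) and sum(p) == k.

    Assumes strictly positive singular values (illustrative assumption).
    """
    n = len(sigma)
    if not 0 < k <= n:
        raise ValueError("k must satisfy 0 < k <= len(sigma)")
    # Visit singular values from largest to smallest.
    order = sorted(range(n), key=lambda i: sigma[i], reverse=True)
    p = [0.0] * n
    tail = float(sum(sigma))  # sum of the not-yet-capped singular values
    capped = 0
    # Cap a component at probability 1 while its proportional share is >= 1.
    while capped < k and (k - capped) * sigma[order[capped]] >= tail:
        p[order[capped]] = 1.0
        tail -= sigma[order[capped]]
        capped += 1
    # The remaining budget (k - capped) is split proportionally to sigma_i.
    if capped < k:
        for i in order[capped:]:
            p[i] = (k - capped) * sigma[i] / tail
    return p


def wheel_sample(p, k, rng=random):
    """Systematic sampling: one spin, k pointers at u, u+1, ..., u+k-1.

    Each index i owns an arc of length p_i <= 1 on a wheel of circumference k,
    so every pointer lands on a distinct index and exactly k are selected.
    """
    pointer = rng.random()
    chosen = []
    cum = 0.0
    last = len(p) - 1
    for i, p_i in enumerate(p):
        # Force the final boundary to k to guard against float rounding.
        cum = float(k) if i == last else cum + p_i
        if pointer < cum:
            chosen.append(i)
            pointer += 1.0
    return chosen


def sample_components(sigma, k, rng=random):
    """Return (index, scale) pairs; scale = 1 / p_i keeps the estimate unbiased."""
    p = inclusion_probabilities(sigma, k)
    return [(i, 1.0 / p[i]) for i in wheel_sample(p, k, rng)]
```

Because the variance expression above depends only on the marginal probabilities $p_i$, any sampler that selects exactly $k$ indices with these marginals would serve equally well in this sketch; systematic sampling is used here because it needs a single random draw.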
