FlashOptim: Optimizers for Memory Efficient Training

Jose Javier Gonzalez Ortiz, Abhay Gupta, Chris Renard, Davis Blalock

TL;DR

This work introduces FlashOptim, a suite of optimizations that reduces per-parameter memory by over 50% while preserving model quality and API compatibility. It also designs companding functions that greatly reduce the error in 8-bit optimizer state quantization.

Abstract

Standard mixed-precision training of neural networks requires many bytes of accelerator memory for each model parameter. These bytes reflect not just the parameter itself, but also its gradient and one or more optimizer state variables. With each of these values typically requiring 4 bytes, training even a 7 billion parameter model can be impractical for researchers with less than 100GB of accelerator memory. We introduce FlashOptim, a suite of optimizations that reduces per-parameter memory by over 50% while preserving model quality and API compatibility. Our approach introduces two key techniques. First, we improve master weight splitting by finding and exploiting a tight bound on its quantization error. Second, we design companding functions that greatly reduce the error in 8-bit optimizer state quantization. Together with 16-bit gradients, these techniques reduce AdamW memory from 16 bytes to 7 bytes per parameter, or 5 bytes with gradient release. They also cut model checkpoint sizes by more than half. Experiments with FlashOptim applied to SGD, AdamW, and Lion show no measurable quality degradation on any task from a collection of standard vision and language benchmarks, including Llama-3.1-8B finetuning.
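The per-parameter totals above can be checked with simple arithmetic. The sketch below uses one hypothetical breakdown that reproduces the stated 7-byte and 5-byte totals: a 16-bit weight plus an 8-bit correction term, 16-bit gradients, and 8-bit first and second moments. The abstract gives only the totals, so this split is an assumption, and the sketch ignores small per-block scale overheads. Going from 16 to 7 bytes is a 56% reduction, and for a 7-billion-parameter model the 16-byte baseline already needs 112 GB (about 104 GiB) before any activations are stored.

```python
# Hypothetical per-parameter byte breakdowns for AdamW. The baseline is standard
# FP32 mixed precision; the FlashOptim split is an ASSUMPTION chosen to match the
# totals stated in the abstract (it ignores small per-block scale overheads).
BASELINE = {
    "fp32_master_weight": 4,
    "fp32_grad": 4,
    "adam_m_fp32": 4,
    "adam_v_fp32": 4,
}
FLASHOPTIM_ASSUMED = {
    "weight_16bit": 2,
    "weight_correction_8bit": 1,
    "grad_16bit": 2,
    "adam_m_8bit": 1,
    "adam_v_8bit": 1,
}


def bytes_per_param(budget, gradient_release=False):
    """Sum per-parameter bytes. With gradient release, each gradient is consumed
    and freed as soon as it is produced, so no full gradient buffer is counted."""
    return sum(b for name, b in budget.items() if not (gradient_release and "grad" in name))


def total_gib(n_params, per_param_bytes):
    """Total bytes for n_params parameters, expressed in GiB."""
    return n_params * per_param_bytes / 2**30


N = 7_000_000_000  # a 7B-parameter model
for label, budget, release in [
    ("baseline", BASELINE, False),
    ("flashoptim (assumed)", FLASHOPTIM_ASSUMED, False),
    ("flashoptim + release", FLASHOPTIM_ASSUMED, True),
]:
    b = bytes_per_param(budget, release)
    print(f"{label:22s} {b:2d} B/param  {total_gib(N, b):6.1f} GiB")
```

Under this assumed breakdown, gradient release removes exactly the 2 bytes of gradient storage, which is why the two FlashOptim totals differ by 2.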

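The first technique, master weight splitting, stores each FP32 master weight as a 16-bit value plus a small correction term. A minimal sketch of the idea, assuming round-to-nearest conversion to BF16 and an int8 correction code: rounding to nearest leaves a residual no larger than half a BF16 ULP, so the residual can be divided by that bound and quantized into the full int8 range, for 3 bytes per weight instead of 4. This only illustrates exploiting such an error bound; it is not the paper's exact scheme, the helper names (`to_bf16`, `half_ulp_bf16`, `split`, `merge`) are hypothetical, and BF16 subnormals (the denormal ranges marked in Figure 3) are not handled specially.

```python
import numpy as np


def to_bf16(x):
    """Round float32 values to bfloat16 (round-to-nearest-even), stored as float32.
    Assumes finite inputs; NaN/Inf are not handled."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & 1                        # last kept bit, for ties-to-even
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then zero the low 16 bits
    return rounded.view(np.float32)


def half_ulp_bf16(h):
    """Half the bfloat16 spacing at h (normal range only).
    For round-to-nearest: |w - to_bf16(w)| <= half_ulp_bf16(to_bf16(w))."""
    _, e = np.frexp(h)             # h = mant * 2**e with 0.5 <= |mant| < 1
    return np.ldexp(1.0, e - 9)    # bf16 ULP is 2**(e - 8); half of it is 2**(e - 9)


def split(w):
    """Split float32 weights into a bf16 'hi' part and an int8 correction code."""
    w = np.asarray(w, dtype=np.float32)
    hi = to_bf16(w)
    r = (w.astype(np.float64) - hi) / half_ulp_bf16(hi)  # lies in [-1, 1] by the bound
    q = np.clip(np.rint(r * 127), -127, 127).astype(np.int8)
    return hi, q


def merge(hi, q):
    """Reconstruct approximate float32 weights from the bf16 part and the code."""
    corr = q.astype(np.float64) / 127.0 * half_ulp_bf16(hi)
    return (hi.astype(np.float64) + corr).astype(np.float32)


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=100_000).astype(np.float32)  # synthetic weights
hi, q = split(w)
rel_bf16 = np.abs(hi - w) / np.abs(w)
rel_split = np.abs(merge(hi, q) - w) / np.abs(w)
print(f"median relative error, bf16 only:            {np.median(rel_bf16):.2e}")
print(f"median relative error, bf16 + int8 residual: {np.median(rel_split):.2e}")
```

Because the correction code is scaled by the known worst-case residual rather than by a per-tensor maximum, none of the 8 bits are wasted on magnitudes the residual can never reach.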
Paper Structure

This paper contains 24 sections, 4 equations, 8 figures, 8 tables, and 6 algorithms.

Figures (8)

  • Figure 1: Memory breakdown for finetuning Llama-3.1-8B. FlashOptim reduces peak memory from 175 to 113 GiB by compressing parameters and optimizer states.
  • Figure 2: Training convergence. Comparison of training loss trajectories between reference optimizers and their FlashOptim variants. Each pair produces nearly identical loss curves throughout training, demonstrating that our memory optimizations do not degrade model quality.
  • Figure 3: FP32 Reconstruction Error. Comparison of FP32 reconstruction error for different weight compression schemes across exponent ranges, for a target datatype of BF16 (top) and FP16 (bottom). Our ULP-based error correction achieves lower relative error, particularly for small exponents. Denormal floating-point ranges are indicated with vertical dotted lines.
  • Figure 4: Optimizer state quantization error. Normalized mean squared error (NMSE) comparison between standard scaled integer quantization (Linear) and our companding approach for momentum ($m$) and variance ($v$) buffers across different optimizers and datasets. Companding reduces quantization error for all optimizers and tensor types, with particularly large improvements for variance tensors.
  • Figure 5: Companding prevents training divergence. GPT-2 training with AdamW and quantized optimizer states: linear quantization (no companding) causes rapid divergence, while our companding approach maintains stable training dynamics.
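Companding applies a nonlinear compressor before uniform quantization and the matching expander after dequantization, so more of the 8-bit code range is spent on small magnitudes. The sketch below compares plain per-block absmax int8 quantization (the "Linear" baseline of Figure 4) against a square-root compander, measuring NMSE as in Figure 4. The square-root function, the block size of 256, and the synthetic mixed-scale data standing in for $m$ and $v$ are all assumptions for illustration; the paper's own companding functions are not reproduced here.

```python
import numpy as np

BLOCK = 256  # hypothetical block size for per-block absmax scaling


def quantize_int8(x, block=BLOCK):
    """Symmetric per-block absmax quantization: int8 codes plus one float32 scale per block."""
    xb = x.reshape(-1, block)
    scale = np.abs(xb).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale).astype(np.float32)  # all-zero blocks
    codes = np.clip(np.rint(xb / scale), -127, 127).astype(np.int8)
    return codes, scale


def dequantize_int8(codes, scale):
    return (codes.astype(np.float32) * scale).reshape(-1)


def compress(y):
    """Square-root compander (a stand-in, not necessarily the paper's function)."""
    return np.sign(y) * np.sqrt(np.abs(y))


def expand(y):
    """Inverse of compress."""
    return np.sign(y) * y * y


def roundtrip(x, companded):
    if companded:
        return expand(dequantize_int8(*quantize_int8(compress(x))))
    return dequantize_int8(*quantize_int8(x))


def nmse(x, x_hat):
    """Normalized mean squared error ||x - x_hat||^2 / ||x||^2."""
    x = x.astype(np.float64)
    return float(np.sum((x - x_hat) ** 2) / np.sum(x ** 2))


rng = np.random.default_rng(0)
n = 256 * 1024
scales = 10.0 ** rng.uniform(-4, -1, n)  # magnitudes spanning three decades
g = rng.standard_normal(n) * scales      # synthetic gradient-like values
m = g.astype(np.float32)                 # stand-in for the first moment (signed)
v = (g ** 2).astype(np.float32)          # stand-in for the second moment (nonnegative)
for name, x in [("m", m), ("v", v)]:
    print(f"{name}: linear NMSE={nmse(x, roundtrip(x, False)):.2e}  "
          f"companded NMSE={nmse(x, roundtrip(x, True)):.2e}")
```

Within a block, linear quantization uses a step set by the largest value, so small entries round to coarse levels or to zero; the compander shrinks that dynamic range first, which matters most for the wide-ranging second moment $v$. For a nonnegative tensor like $v$, an unsigned code range would roughly double the resolution; the symmetric int8 format here just keeps the sketch short.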