Omni-Masked Gradient Descent: Memory-Efficient Optimization via Mask Traversal with Improved Convergence

Hui Yang, Tao Ren, Jinyang Jiang, Wan Tian, Yijie Peng

TL;DR

Omni-Masked Gradient Descent (OMGD) is a lightweight, plug-and-play approach that integrates seamlessly into most mainstream optimizers, yielding consistent improvements over competitive baselines in both fine-tuning and pre-training tasks.

Abstract

Memory-efficient optimization methods have recently gained increasing attention for scaling full-parameter training of large language models under the GPU-memory bottleneck. Existing approaches either lack clear convergence guarantees or achieve only the standard $\mathcal{O}(\epsilon^{-4})$ iteration complexity in the nonconvex setting. We propose Omni-Masked Gradient Descent (OMGD), an optimization method based on mask traversal for memory-efficient training, and provide a nonconvex convergence analysis that establishes a strictly improved iteration complexity of $\tilde{\mathcal{O}}(\epsilon^{-3})$ for finding an $\epsilon$-approximate stationary point. Empirically, OMGD is a lightweight, plug-and-play approach that integrates seamlessly into most mainstream optimizers, yielding consistent improvements over competitive baselines in both fine-tuning and pre-training tasks.

Paper Structure (39 sections, 12 theorems, 115 equations, 7 figures, 8 tables, 2 algorithms)

Key Result

Lemma 4.4

If Assumptions \ref{low_boundedness}-\ref{uniformly_control} hold, then for any $\tau\ge0$ and $m>0$, the sequence $\{(S_t,z_t)\}$ generated by Algorithm \ref{alg:1} satisfies where the constants $C,\Phi$ depend on $C_1,C_2,M,N$.

Figures (7)

  • Figure 1: Illustration of the epochwise mask application in OMGD with $d=8,M=4,N=4$. ① Generated masks satisfy the condition given in \ref{mask_requirement}. ② Outer loop processes the $M$ masks sequentially, corresponding to $M$ consecutive epochs within one cycle. ③ Inner loop performs a full dataset pass for each mask, computing the masked gradient defined in \ref{mask_gradient_form} to update the model parameters.
  • Figure 2: Squared $L^2$ norm of overall error, decay term, data-reshuffle term, and compression-error term.
  • Figure 3: Test loss of fine-tuning ViT on CIFAR-10.
  • Figure 4: Training loss of fine-tuning RoBERTa-Base on CoLA.
  • Figure 5: Training loss of pre-training GPT-2-124M.
  • ...and 2 more figures
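
The loop structure described in the Figure 1 caption can be sketched roughly as follows, using the same sizes $d=8,M=4,N=4$. This is a minimal illustration under stated assumptions, not the paper's algorithm: the mask requirement and the masked-gradient formula are not reproduced in this summary, so the sketch assumes that the $M$ masks are disjoint 0/1 vectors covering every coordinate exactly once, that the masked gradient is the elementwise product of mask and gradient, and that the data order is reshuffled each epoch. The names `make_partition_masks`, `omgd_cycle`, and `quad_grad`, as well as the toy quadratic objective, are hypothetical.

```python
import random

def make_partition_masks(d, M, rng):
    """Hypothetical mask generator (assumption): split the d coordinates into
    M disjoint 0/1 masks so that every coordinate is covered exactly once."""
    perm = list(range(d))
    rng.shuffle(perm)
    masks = [[0] * d for _ in range(M)]
    for pos, coord in enumerate(perm):
        masks[pos % M][coord] = 1
    return masks

def omgd_cycle(x, data, masks, lr, grad_fn, rng):
    """One cycle = M consecutive epochs; each epoch uses one mask."""
    for mask in masks:                      # outer loop over the M masks
        order = list(range(len(data)))
        rng.shuffle(order)                  # per-epoch reshuffle (assumption)
        for i in order:                     # inner loop: full dataset pass
            g = grad_fn(x, data[i])
            # Masked gradient, assumed here to be mask * gradient elementwise:
            # only coordinates selected by the current mask are updated.
            x = [xj - lr * mj * gj for xj, mj, gj in zip(x, mask, g)]
    return x

def quad_grad(x, target):
    """Gradient of the toy per-sample loss 0.5 * ||x - target||^2."""
    return [xj - tj for xj, tj in zip(x, target)]

rng = random.Random(0)
d, M, N = 8, 4, 4
masks = make_partition_masks(d, M, rng)
data = [[float(n)] * d for n in range(N)]   # N toy samples
x = [10.0] * d
for _ in range(50):                         # run 50 cycles of M epochs each
    x = omgd_cycle(x, data, masks, 0.1, quad_grad, rng)
```

In this toy setup each coordinate is updated during exactly one epoch per cycle, so over a full cycle every parameter sees one complete pass over the data, which is one plausible reading of "mask traversal".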

Theorems & Definitions (29)

  • Lemma 4.4
  • Lemma 4.5: Descent lemma
  • Theorem 4.6: Convergence rate, nonconvex case
  • Theorem 4.8: Convergence rate, $\mu$-PL condition
  • Proposition 4.9
  • Remark 4.10
  • Remark 4.11
  • Remark 4.12
  • Remark 5.1
  • Remark 5.2
  • ...and 19 more