Sparse MeZO: Less Parameters for Better Performance in Zeroth-Order LLM Fine-Tuning

Yong Liu, Zirui Zhu, Chaoyu Gong, Minhao Cheng, Cho-Jui Hsieh, Yang You

TL;DR

Sparse-MeZO tackles memory constraints in fine-tuning large language models by applying zeroth-order optimization selectively to a sparsified subnetwork. It reveals that gradient noise in ZO estimates disproportionately harms large weights, motivating a percentile-based dynamic masking strategy and a memory-efficient forward-pass implementation. The approach yields consistent performance gains and faster convergence across LLaMA, OPT, and Mistral on SuperGLUE and related tasks, with substantial memory savings enabling training on a single GPU for large models. Theoretical convergence guarantees support the empirical gains, showing that focusing updates on a compact subnetwork can accelerate optimization without increasing memory overhead.

Abstract

While fine-tuning large language models (LLMs) for specific tasks often yields impressive results, it is memory-inefficient because gradient-based training relies on back-propagation. Memory-efficient Zeroth-order (MeZO) optimizers, recently proposed to address this issue, require only forward passes during training, making them more memory-friendly. However, compared with exact gradients, ZO-based gradients usually exhibit an estimation error, which can significantly hurt the optimization process, leading to slower convergence and suboptimal solutions. In addition, we find that the estimation error is more harmful when it is added to large weights than to small weights. Based on this observation, this paper introduces Sparse MeZO, a novel memory-efficient zeroth-order optimization approach that applies ZO only to a carefully chosen subset of parameters. We propose a simple yet effective parameter selection scheme that yields significant performance gains with Sparse-MeZO. Additionally, we develop a memory-optimized implementation for sparse masking, ensuring the algorithm requires only inference-level memory consumption, allowing Sparse-MeZO to fine-tune LLaMA-30b on a single A100 GPU. Experimental results illustrate that Sparse-MeZO consistently improves both performance and convergence speed over MeZO without any overhead. For example, it achieves a 9% absolute accuracy improvement and 3.5x speedup over MeZO on the RTE task. Code is available at https://github.com/NUS-HPC-AI-Lab/SparseMeZO.
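
The idea described in the abstract, a zeroth-order update restricted to a subset of (small-magnitude) weights, can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the function names `small_weight_mask` and `sparse_zo_step`, the single global percentile threshold, the toy quadratic loss, and all hyperparameter values are assumptions chosen for illustration, and the sketch does not reproduce the paper's memory-optimized masking.

```python
import numpy as np

def small_weight_mask(theta, percentile=80.0):
    """Hypothetical helper: select weights whose magnitude is at or below a percentile."""
    threshold = np.percentile(np.abs(theta), percentile)
    return (np.abs(theta) <= threshold).astype(theta.dtype)

def sparse_zo_step(theta, loss_fn, lr=1e-2, eps=1e-3, percentile=80.0, seed=0):
    """One masked two-point (SPSA-style) step using only two forward evaluations."""
    rng = np.random.default_rng(seed)
    mask = small_weight_mask(theta, percentile)
    # Perturb only the selected weights; masked-out entries get a zero perturbation.
    z_hat = mask * rng.standard_normal(theta.shape)
    loss_plus = loss_fn(theta + eps * z_hat)
    loss_minus = loss_fn(theta - eps * z_hat)
    # Scalar estimate of the directional derivative along z_hat.
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)
    return theta - lr * projected_grad * z_hat, mask

# Toy problem (assumed): a quadratic loss around a fixed target vector.
target = np.array([0.1, -0.2, 3.0, 0.05, -4.0])
theta = np.array([0.5, 0.3, 3.5, -0.4, -4.5])
loss_fn = lambda t: float(np.sum((t - target) ** 2))

new_theta, mask = sparse_zo_step(theta, loss_fn, percentile=60.0)
```

In this toy run the two large-magnitude entries are excluded by the mask and stay exactly unchanged, while the three small-magnitude entries receive the zeroth-order update.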

Paper Structure

This paper contains 33 sections, 3 theorems, 27 equations, 4 figures, 13 tables, 3 algorithms.

Key Result

Lemma 1

The ZO gradient $\bm{g}_{\hat{z}}(\bm{\theta})$ is an unbiased estimate of $\widehat{\nabla}_{\bm{\theta}} \mathcal{L}_{\hat{z}}(\bm{\theta})$.
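
As rough intuition for a statement of this kind, consider the standard two-point (SPSA) estimator used by MeZO with a masked perturbation. The symbols below are assumptions for illustration and not the paper's exact statement or proof: $\bm{m}$ is a binary mask, $\bm{z} \sim \mathcal{N}(\bm{0}, \bm{I})$, $\hat{\bm{z}} = \bm{m} \odot \bm{z}$, and $\epsilon$ is the perturbation scale. A first-order Taylor expansion then gives:

$$
\begin{aligned}
\bm{g}_{\hat{z}}(\bm{\theta})
  &= \frac{\mathcal{L}(\bm{\theta} + \epsilon \hat{\bm{z}}) - \mathcal{L}(\bm{\theta} - \epsilon \hat{\bm{z}})}{2\epsilon}\, \hat{\bm{z}}
   \approx \hat{\bm{z}} \hat{\bm{z}}^{\top} \nabla_{\bm{\theta}} \mathcal{L}(\bm{\theta}), \\
\mathbb{E}_{\bm{z}}\!\left[\hat{\bm{z}} \hat{\bm{z}}^{\top}\right]
  &= \operatorname{diag}(\bm{m})
  \quad\Longrightarrow\quad
  \mathbb{E}_{\bm{z}}\!\left[\bm{g}_{\hat{z}}(\bm{\theta})\right] \approx \bm{m} \odot \nabla_{\bm{\theta}} \mathcal{L}(\bm{\theta}).
\end{aligned}
$$

Under these assumptions, the masked estimator matches, in expectation and to first order, the true gradient restricted to the selected coordinates; for a quadratic loss the central difference is exact and the approximation becomes an equality.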

Figures (4)

  • Figure 1: Performance of MeZO and Sparse-MeZO (S-MeZO) on the RTE task. S-MeZO achieves a $3.5$x speedup compared with MeZO.
  • Figure 2: (a) Test accuracy with different learning rates on the RTE task. MeZO is very sensitive to the choice of learning rate: even a small increase from $1 \times 10^{-6}$ to $2 \times 10^{-6}$ causes divergence and instability. (b) Probability of loss increase on different batches. The estimated ZO gradient reliably reduces the loss on the batch it was computed on, but often fails to reduce the loss on a new held-out batch. (c) Continuing training from the drop point while updating only small or only large weights. Optimizing only the small weights recovers and further improves test accuracy.
  • Figure 3: Convergence curves of fine-tuning LLaMA-7b with MeZO and Sparse-MeZO (S-MeZO) on the (a) RTE, (b) BoolQ, and (c) WiC tasks.
  • Figure 4: (a) Probability of loss increase with MeZO on different batches. (b) Probability of loss increase with SGD on different batches. The probability of a loss increase is computed for each epoch.

Theorems & Definitions (3)

  • Lemma 1
  • Lemma 2
  • Theorem 1