MISA: Memory-Efficient LLMs Optimization with Module-wise Importance Sampling

Yuxi Liu, Renjia Deng, Yutong He, Xue Wang, Tao Yao, Kun Yuan

TL;DR

MISA introduces module-wise importance sampling to address memory bottlenecks in fine-tuning large language models by partitioning transformer layers into fine-grained modules and sampling updates based on real-time gradient importance. The approach provides a provable reduction in gradient variance, a convergence rate of \(\mathcal{O}(1/\sqrt{NT})\) under non-convex stochastic settings, and detailed memory analyses showing memory savings over layer-wise and PEFT baselines, especially for long-sequence tasks. Extensive experiments across commonsense and math reasoning, instruction tuning, and pre-training demonstrate that MISA consistently outperforms baselines under comparable memory budgets, while offering robust ablations and practical guidelines for hyperparameters. These results highlight MISA’s potential to enable scalable, memory-efficient fine-tuning of LLMs without sacrificing performance.

Abstract

The substantial memory demands of pre-training and fine-tuning large language models (LLMs) require memory-efficient optimization algorithms. One promising approach is layer-wise optimization, which treats each transformer block as a single layer and optimizes it sequentially, while freezing the other layers to save optimizer states and activations. Although effective, these methods ignore the varying importance of the modules within each layer, leading to suboptimal performance. Moreover, layer-wise sampling provides only limited memory savings, as at least one full layer must remain active during optimization. To overcome these limitations, we propose Module-wise Importance SAmpling (MISA), a novel method that divides each layer into smaller modules and assigns importance scores to each module. MISA uses a weighted random sampling mechanism to activate modules, provably reducing gradient variance compared to layer-wise sampling. Additionally, we establish an \(\mathcal{O}(1/\sqrt{K})\) convergence rate under non-convex and stochastic conditions, where $K$ is the total number of block updates, and provide a detailed memory analysis showcasing MISA's superiority over existing baseline methods. Experiments on diverse learning tasks validate the effectiveness of MISA. Source code is available at https://github.com/pkumelon/MISA.
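
The weighted random sampling described above can be illustrated with a minimal sketch. It is not the authors' implementation: the module names, the example scores, the `temperature` parameter, and the softmax form of the probabilities are all assumptions made for illustration, since the closed-form sampling probability from the paper is not reproduced in this excerpt. The sketch turns per-module importance scores (for example, running gradient norms) into probabilities and draws a small set of modules to unfreeze for the next update.

```python
import math
import random


def sampling_probabilities(scores, temperature=1.0):
    """Map importance scores to probabilities with a softmax.

    Assumption: the softmax form and `temperature` are illustrative
    choices, not the formula derived in the paper.
    """
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp((s - top) / temperature) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sample_modules(names, scores, k, rng, temperature=1.0):
    """Draw up to k distinct modules, weighted by importance."""
    probs = sampling_probabilities(scores, temperature)
    pool = list(zip(names, probs))
    chosen = []
    for _ in range(min(k, len(pool))):
        weights = [p for _, p in pool]
        idx = rng.choices(range(len(pool)), weights=weights, k=1)[0]
        chosen.append(pool.pop(idx)[0])  # remove so it is not drawn twice
    return chosen


# Hypothetical modules of one transformer layer with made-up scores.
names = [
    "layer0.q_proj", "layer0.k_proj", "layer0.v_proj",
    "layer0.o_proj", "layer0.mlp.up_proj", "layer0.mlp.down_proj",
]
scores = [0.8, 0.3, 1.5, 0.6, 2.1, 1.1]

rng = random.Random(0)
probs = sampling_probabilities(scores)
active = sample_modules(names, scores, k=2, rng=rng)
print(active)  # only these modules would hold gradients and optimizer state
```

In a training loop, all other modules would stay frozen for the current block of steps, so gradients and optimizer states are kept only for the sampled modules. Lowering `temperature` concentrates sampling on the highest-scoring modules, while raising it moves toward uniform sampling.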

Paper Structure

This paper contains 54 sections, 15 theorems, 131 equations, 11 figures, 24 tables, 3 algorithms.

Key Result

Proposition 1

The optimal solution to problem (eq:prob:1:xue) gives the sampling probability of block $b$ at iteration $n$.

Figures (11)

  • Figure 1: The gradient norm of different modules in different layers when fine-tuning LLaMA3-8B.
  • Figure 2: Comparison of peak memory consumption using MISA and LoRA fine-tuning on LLaMA3-8B across various sequence lengths.
  • Figure 3: Validation loss of LISA, BAdam, and MISA across three epochs of fine-tuning Mistral-7B (left), LLaMA2-7B (middle) and TinyLLaMA (right) on the Alpaca-GPT4 dataset. The x-axis represents training time (minutes).
  • Figure 4: Pre-training dynamics for LLaMA 130M (left) and LLaMA 350M (right) on the C4 dataset.
  • Figure 5: Comparison of peak memory consumption on LLaMA3-8B and LLaMA3-70B. Panel (c) uses the flash-attention technique.
  • ...and 6 more figures

Theorems & Definitions (30)

  • Proposition 1
  • Remark 1: Memory and Computation overhead
  • Remark 2: Clarification on "Layer", "Module", and "Block"
  • Proposition 2
  • Theorem 1: Informal
  • Remark 3
  • Proposition 3
  • proof
  • Proposition 4
  • proof
  • ...and 20 more