BAdam: A Memory Efficient Full Parameter Optimization Method for Large Language Models

Qijun Luo, Hengxu Yu, Xiao Li

Abstract

This work presents BAdam, an optimization method that leverages the block coordinate descent (BCD) framework with Adam's update rule. BAdam offers a memory-efficient approach to full parameter finetuning of large language models. We conduct a theoretical convergence analysis for BAdam in the deterministic case. Experimentally, we apply BAdam to finetune the Llama 3-8B and Llama 3-70B models using a single RTX3090-24GB GPU and 4 A100-80GB GPUs, respectively. The results confirm BAdam's efficiency in terms of memory usage, running time, and optimization capability. Furthermore, downstream performance evaluation based on MT-bench and math benchmarks shows that BAdam outperforms existing memory-efficient baselines such as LoRA. It also demonstrates that BAdam can achieve performance comparable or even superior to that of Adam. Finally, an ablation study using SGD's update rule illustrates the suitability of BCD for finetuning LLMs. Our code can be easily integrated into any PyTorch-based codebase and is available at https://github.com/Ledzy/BAdam.
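The idea of combining the BCD framework with Adam's update rule can be illustrated with a small, hypothetical sketch. It is not the authors' implementation: it assumes a toy separable quadratic objective, a fixed ascending block order, and K Adam steps per active block, and it keeps Adam's moment estimates only for the active block, discarding them when switching blocks (an assumption about how memory is saved). The names `block_adam`, `loss`, and `grad` are illustrative only.

```python
# Hypothetical sketch of block coordinate descent (BCD) with Adam's update rule.
# Assumptions (not the authors' code): toy quadratic objective, fixed ascending
# block order, K Adam steps per block, Adam states kept only for the active block.
import math
from typing import List, Sequence


def loss(x: Sequence[float], a: Sequence[float], b: Sequence[float]) -> float:
    """Toy objective f(x) = 0.5 * sum_i a_i * (x_i - b_i)^2 with a_i > 0."""
    return 0.5 * sum(ai * (xi - bi) ** 2 for xi, ai, bi in zip(x, a, b))


def grad(x: Sequence[float], a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Exact (deterministic) gradient of the toy objective."""
    return [ai * (xi - bi) for xi, ai, bi in zip(x, a, b)]


def block_adam(x0, a, b, blocks, block_epochs=20, K=10,
               lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8) -> List[float]:
    x = list(x0)
    for _ in range(block_epochs):      # one block-epoch = one pass over all blocks
        for block in blocks:           # only this block is trainable right now
            # Fresh moment estimates for the active block only; they are
            # dropped when the loop moves on to the next block.
            m = {i: 0.0 for i in block}
            v = {i: 0.0 for i in block}
            for k in range(1, K + 1):
                # For simplicity the full gradient is computed, but only the
                # active block's entries are used; inactive blocks stay frozen.
                g = grad(x, a, b)
                for i in block:
                    m[i] = beta1 * m[i] + (1 - beta1) * g[i]
                    v[i] = beta2 * v[i] + (1 - beta2) * g[i] ** 2
                    m_hat = m[i] / (1 - beta1 ** k)  # bias corrections
                    v_hat = v[i] / (1 - beta2 ** k)
                    x[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


if __name__ == "__main__":
    a = [1.0, 2.0, 4.0, 0.5]
    b = [1.0, -2.0, 3.0, 0.5]
    x0 = [0.0] * 4
    x = block_adam(x0, a, b, blocks=[[0, 1], [2, 3]])
    print(f"initial loss {loss(x0, a, b):.4f}, final loss {loss(x, a, b):.6f}")
```

Only the active block's optimizer state exists at any moment, which is the property that makes the approach memory-light; in a real PyTorch model the blocks would be groups of layers rather than individual scalars.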

Paper Structure (24 sections, 7 theorems, 29 equations, 5 figures, 12 tables)

Key Result

Theorem 2.1

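$\mathsf{BAdam}$ using deterministic gradients is a descent method under certain conditions commonly used to analyze block coordinate descent methods and Adam. That is, after one block-epoch of updates for the whole model, we have $F(\theta^{t+1}) - F(\theta^{t}) \le -\mathcal{O}\big(\|\nabla F(\theta^{t})\|^2\big)$, where $F$ is the training objective and $\theta^{t}$ denotes the model parameters after $t$ block-epochs. Consequently, $\mathsf{BAdam}$ finds a $\delta$-approximate stationary point within $\mathcal{O}(\delta^{-2})$ iterations.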
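The step from a per-block-epoch descent property to the $\mathcal{O}(\delta^{-2})$ complexity is the standard telescoping argument. The derivation below assumes, for illustration, a sufficient-decrease inequality with some constant $c > 0$ and an objective $F$ bounded below by $F^\star$; the paper's precise constants are not reproduced here.

```latex
% Assumed: F(\theta^{t+1}) \le F(\theta^{t}) - c\,\|\nabla F(\theta^{t})\|^2 for all t, c > 0, F \ge F^\star.
\begin{align*}
c\sum_{t=0}^{T-1}\|\nabla F(\theta^{t})\|^2
  &\le \sum_{t=0}^{T-1}\bigl(F(\theta^{t}) - F(\theta^{t+1})\bigr)
   = F(\theta^{0}) - F(\theta^{T})
   \le F(\theta^{0}) - F^\star, \\
\min_{0 \le t < T}\|\nabla F(\theta^{t})\|^2
  &\le \frac{1}{T}\sum_{t=0}^{T-1}\|\nabla F(\theta^{t})\|^2
   \le \frac{F(\theta^{0}) - F^\star}{c\,T}.
\end{align*}
% Hence \min_t \|\nabla F(\theta^{t})\| \le \delta once
% T \ge (F(\theta^{0}) - F^\star)/(c\,\delta^{2}), i.e. T = \mathcal{O}(\delta^{-2}) block-epochs.
```

Because each block-epoch consists of a fixed number of Adam steps ($K$ steps for each block), counting individual iterations instead of block-epochs changes only the constant, not the order $\mathcal{O}(\delta^{-2})$.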

Figures (5)

  • Figure 1: Illustration of the proposed $\mathsf{BAdam}$, which is based on the block coordinate descent framework. Colors represent the states of the partitioned blocks in one block-epoch, including the active block, non-updated inactive blocks, and updated inactive blocks.
  • Figure 2: Optimization capability of $\mathsf{BAdam}$ for finetuning Llama 3-8B on Alpaca-GPT4 dataset. Left: Online training loss of LoRA and $\mathsf{BAdam}$. Middle: Cumulative explained variance of $\mathsf{BAdam}$'s learned perturbation to the 25th layer's up-proj matrix. Right: Effective rank of Adam's and $\mathsf{BAdam}$'s learned perturbations.
  • Figure 3: Ablation study for BCD variants and their full counterparts for finetuning Llama 3-8B on Alpaca-GPT4 dataset. Left and middle: Convergence behavior. Right: MT-bench scores.
  • Figure 4: Effect of block ordering strategies and the number of Adam steps $K$ per block.
  • Figure 5: Loss of continual pretraining of Llama 3.1-8B-Instruct on the StarCoder-Python dataset using $\mathsf{BAdam}$.
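Figures 1 and 4 refer to block states within a block-epoch, block ordering strategies, and the number $K$ of Adam steps per block. The sketch below illustrates that bookkeeping under stated assumptions: the strategy names `"ascending"`, `"descending"`, and `"random"` (random reshuffling each block-epoch) and all function names are hypothetical, not taken from the paper or its code.

```python
# Hypothetical bookkeeping for one block-epoch (cf. Figures 1 and 4).
# Strategy names and function names are illustrative assumptions.
import random
from typing import List, Optional


def block_order(num_blocks: int, strategy: str,
                rng: Optional[random.Random] = None) -> List[int]:
    """Order in which blocks are visited during one block-epoch."""
    order = list(range(num_blocks))
    if strategy == "ascending":
        return order
    if strategy == "descending":
        return order[::-1]
    if strategy == "random":
        # Random reshuffling: a fresh permutation each block-epoch, so every
        # block is still visited exactly once per epoch.
        rng = rng if rng is not None else random.Random()
        rng.shuffle(order)
        return order
    raise ValueError(f"unknown strategy: {strategy!r}")


def block_states(order: List[int], position: int) -> List[str]:
    """State of every block (indexed by block id) while order[position] is active."""
    states = [""] * len(order)
    for pos, blk in enumerate(order):
        if pos < position:
            states[blk] = "updated inactive"
        elif pos == position:
            states[blk] = "active"
        else:
            states[blk] = "non-updated inactive"
    return states


def epoch_schedule(order: List[int], K: int) -> List[int]:
    """Active block at each optimizer step: K consecutive Adam steps per block."""
    return [blk for blk in order for _ in range(K)]


if __name__ == "__main__":
    order = block_order(4, "descending")
    print(order)                       # [3, 2, 1, 0]
    print(block_states(order, 1))      # ['non-updated inactive', 'non-updated inactive', 'active', 'updated inactive']
    print(epoch_schedule(order, K=2))  # [3, 3, 2, 2, 1, 1, 0, 0]
```

Whatever the ordering, each block-epoch contains exactly `num_blocks * K` optimizer steps, so the ordering strategy changes which block is refreshed first, not the amount of work per epoch.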

Theorems & Definitions (13)

  • Theorem 2.1: informal
  • Corollary D.3: bounded adaptive step sizes
  • proof
  • Theorem D.4: descent method
  • Corollary D.5: first-order complexity
  • proof
  • Lemma D.6: approximate descent inequality for one block
  • proof
  • Lemma D.7: bound for error term
  • proof
  • ...and 3 more
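Corollary D.3 concerns bounded adaptive step sizes. As a generic illustration (an assumption, not necessarily the corollary's exact statement or constants), if every gradient entry satisfies $|g_s| \le G$, Adam's bias-corrected second-moment estimate is a convex combination of squared gradients, which confines the per-coordinate step size between two positive constants:

```latex
% Assumed: |g_s| \le G for all s, learning rate \eta > 0, \epsilon > 0, 0 < \beta_2 < 1.
\begin{align*}
\hat v_t &= \frac{(1-\beta_2)\sum_{s=1}^{t}\beta_2^{\,t-s} g_s^2}{1-\beta_2^{\,t}},
\qquad
\frac{(1-\beta_2)\sum_{s=1}^{t}\beta_2^{\,t-s}}{1-\beta_2^{\,t}} = 1
\;\Longrightarrow\; 0 \le \hat v_t \le G^2, \\
\frac{\eta}{G+\epsilon} &\le \frac{\eta}{\sqrt{\hat v_t}+\epsilon} \le \frac{\eta}{\epsilon}.
\end{align*}
```

The lower bound keeps each active-block update from vanishing, and the upper bound keeps it from blowing up, which is the kind of control a one-block descent inequality needs.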