Train Small, Infer Large: Memory-Efficient LoRA Training for Large Language Models

Jun Zhang, Jue Wang, Huan Li, Lidan Shou, Ke Chen, Yang You, Guiming Xie, Xuejian Gong, Kunlong Zhou

TL;DR

LoRAM is a memory-efficient approach to LoRA fine-tuning: it trains low-rank adapters on a pruned, smaller model, then recovers them to full size and merges them with the original model for inference. It augments this workflow with a low-cost continual pre-training alignment step that mitigates the knowledge mismatch between the pruned and full models, so that publishers can release aligned pruned variants. Across multiple pruning strategies, model sizes (e.g., LLaMA-2-13B/70B, LLaMA-3.1-70B), and downstream tasks, LoRAM achieves substantial memory reductions (up to a $16.95\times$ reduction in parameter storage cost with QLoRAM) and competitive or superior inference performance, with notable gains in math and code domains. The work demonstrates the practical feasibility of memory-efficient LoRA training on consumer hardware and outlines future directions for applying the idea to other domains such as vision transformers and diffusion models.

Abstract

Large Language Models (LLMs) have significantly advanced natural language processing with exceptional task generalization capabilities. Low-Rank Adaptation (LoRA) offers a cost-effective fine-tuning solution, freezing the original model parameters and training only lightweight, low-rank adapter matrices. However, the memory footprint of LoRA is largely dominated by the original model parameters. To mitigate this, we propose LoRAM, a memory-efficient LoRA training scheme founded on the intuition that many neurons in over-parameterized LLMs have low training utility but are essential for inference. LoRAM presents a unique twist: it trains on a pruned (small) model to obtain pruned low-rank matrices, which are then recovered and utilized with the original (large) model for inference. Additionally, minimal-cost continual pre-training, performed by the model publishers in advance, bridges the knowledge discrepancy between the pruned and original models. Our extensive experiments demonstrate the efficacy of LoRAM across various pruning strategies and downstream tasks. For a model with 70 billion parameters, LoRAM enables training on a GPU with only 20 GB of HBM, replacing an A100-80G GPU for LoRA training and 15 GPUs for full fine-tuning. Specifically, QLoRAM, implemented by structured pruning combined with 4-bit quantization, reduces the parameter storage cost that dominates memory usage in low-rank matrix training by 15.81$\times$ for LLaMA-3.1-70B (16.95$\times$ for LLaMA-2-70B), while achieving dominant performance gains over both the original LLaMA-3.1-70B (LLaMA-2-70B) and LoRA-trained LLaMA-3.1-8B (LLaMA-2-13B). Code is available at https://github.com/junzhang-zj/LoRAM.
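The train-small, infer-large workflow described in the abstract can be illustrated on a single weight matrix. The following is a minimal NumPy sketch, not the authors' implementation: it assumes structured pruning that drops whole rows and columns, and it assumes that recovery means zero-filling the pruned positions of the low-rank factors. The function names (`prune`, `recover_lora`, `merge`) and the toy shapes are hypothetical, and random values stand in for adapters that would normally be trained on the pruned model.

```python
import numpy as np

def prune(w_full, keep_rows, keep_cols):
    """Structured pruning (assumed form): keep only selected rows and columns."""
    return w_full[np.ix_(keep_rows, keep_cols)]

def recover_lora(b_pruned, a_pruned, keep_rows, keep_cols, full_shape):
    """Expand pruned low-rank factors to full size.

    Assumption: pruned positions are filled with zeros, so the recovered
    update leaves the pruned rows and columns of the full weight untouched.
    """
    d_out, d_in = full_shape
    rank = b_pruned.shape[1]
    b_full = np.zeros((d_out, rank), dtype=b_pruned.dtype)
    a_full = np.zeros((rank, d_in), dtype=a_pruned.dtype)
    b_full[keep_rows, :] = b_pruned
    a_full[:, keep_cols] = a_pruned
    return b_full, a_full

def merge(w_full, b_full, a_full):
    """Standard LoRA merge: add the low-rank update to the frozen weight."""
    return w_full + b_full @ a_full

# Toy dimensions (hypothetical): a 6x8 weight with a rank-2 adapter.
rng = np.random.default_rng(0)
d_out, d_in, rank = 6, 8, 2
w_full = rng.normal(size=(d_out, d_in))
keep_rows = np.array([0, 2, 3, 5])
keep_cols = np.array([0, 1, 4, 6, 7])

# Training would use only the small matrix; the full one stays offline.
w_small = prune(w_full, keep_rows, keep_cols)

# Stand-ins for adapters trained on the pruned model.
b_small = rng.normal(size=(len(keep_rows), rank))
a_small = rng.normal(size=(rank, len(keep_cols)))

# Inference: recover the adapters to full size and merge with the full weight.
b_full, a_full = recover_lora(b_small, a_small, keep_rows, keep_cols, w_full.shape)
w_merged = merge(w_full, b_full, a_full)
```

Under these assumptions, the merged weight equals the pruned-model result on the kept rows and columns, while the pruned rows and columns retain their original values and contribute only at inference time.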

Paper Structure

This paper contains 63 sections, 11 equations, 15 figures, 8 tables, 1 algorithm.

Figures (15)

  • Figure 1: Idea of LoRAM
  • Figure 2: Comparison of LoRAM and LoRA: Training (subfigures a and b) and Inference (c and d). Key stages include the offline process of the frozen full-rank matrix $\mathbf{W}_{0}^{*}$ (subfigure e) and the online generation of the learnable low-rank matrix $\mathbf{W}_{\Delta}^{*}$ (f) during LoRAM training (b) and inference (d).
  • Figure 3: The test perplexity of training LLaMA-2-13B & LLaMA-2-70B on OpenHermes.
  • Figure 4: The test perplexity of training LLaMA-2-13B & LLaMA-2-70B on OpenOrca.
  • Figure 5: The test perplexity & downstream performance of training LLaMA-3.1-70B on OpenHermes.
  • ...and 10 more figures