Mini-batch Coresets for Memory-efficient Language Model Training on Data Mixtures

Dang Nguyen, Wenhan Yang, Rathul Anand, Yu Yang, Baharan Mirzasoleiman

TL;DR

The paper introduces CoLM, a data-centric approach to memory-efficient fine-tuning of large language models. CoLM constructs small mini-batch coresets whose gradients approximate those of larger mini-batches. It tackles three core challenges (imbalanced data sources, the Adam optimizer, and high gradient dimensionality) in three ways. It includes all examples from small sources, normalizes gradients by their historical exponential averages for Adam, and uses zeroth-order gradient estimates with sparsification to make medoid selection efficient. On MathInstruct and SuperGLUE, CoLM reduces activation memory by about 2x and can outperform training with 4x larger mini-batches, while also speeding up training and integrating with LoRA. Experiments show the method generalizes across multiple model families and data settings, offering a practical path to memory-efficient, high-performance fine-tuning of LLMs.
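The source-aware selection step described above can be sketched in Python. This is a minimal illustration under stated assumptions, not the CoLM implementation. `select_coreset` and `greedy_medoids` are hypothetical helpers, and `feats` stands in for whatever per-example gradient features are used. The set of small sources is assumed to be known in advance from the full data distribution. The remaining budget is split across big sources in proportion to their counts in the large batch. Medoids are approximated by greedy facility-location maximization, a standard proxy for k-medoids in gradient-matching coresets.

```python
import numpy as np


def greedy_medoids(feats: np.ndarray, k: int) -> list[int]:
    """Greedy facility-location selection, a common proxy for k-medoids.

    Picks up to k rows of `feats` maximizing sum_i max_{j in S} sim(i, j),
    with sim(i, j) = max_dist - ||x_i - x_j||_2 (always >= 0).
    """
    n = feats.shape[0]
    k = min(k, n)
    # (n, n) pairwise Euclidean distances via an (n, n, d) difference tensor
    dists = np.linalg.norm(feats[:, None, :] - feats[None, :, :], axis=-1)
    sims = dists.max() - dists
    best = np.zeros(n)  # similarity of each point to its closest selected medoid
    selected: list[int] = []
    for _ in range(k):
        # marginal gain of adding candidate j: sum_i max(0, sims[i, j] - best[i])
        gains = np.maximum(sims - best[:, None], 0.0).sum(axis=0)
        gains[np.asarray(selected, dtype=int)] = -np.inf  # never re-pick
        j = int(np.argmax(gains))
        selected.append(j)
        best = np.maximum(best, sims[:, j])
    return selected


def select_coreset(feats, sources, budget, small_sources):
    """Keep all small-source examples, fill the rest with per-source medoids.

    Hypothetical helper: `sources` holds one source id per row of `feats`,
    `small_sources` is a set of source ids treated as small. If the small
    sources alone exceed `budget`, they are still all kept.
    Returns sorted row indices of the selected mini-batch coreset.
    """
    sources = np.asarray(sources)
    keep = [i for i, s in enumerate(sources.tolist()) if s in small_sources]
    big_ids = [s for s in np.unique(sources).tolist() if s not in small_sources]
    remaining = max(budget - len(keep), 0)
    if not big_ids or remaining == 0:
        return sorted(keep)
    counts = np.array([np.sum(sources == s) for s in big_ids])
    # proportional split (floor); leftover slots go to the largest sources
    alloc = np.floor(remaining * counts / counts.sum()).astype(int)
    for idx in np.argsort(-counts, kind="stable")[: remaining - int(alloc.sum())]:
        alloc[idx] += 1
    for s, k_s in zip(big_ids, alloc):
        members = np.flatnonzero(sources == s)
        picked = greedy_medoids(feats[members], int(k_s))
        keep.extend(members[np.asarray(picked, dtype=int)].tolist())
    return sorted(keep)
```

Keeping every small-source example sidesteps the failure mode illustrated in Figure 1, where the medoids of a random sample can miss the small source entirely. Note that this sketch materializes an n×n×d difference tensor, so it only suits small batches or low-dimensional features, which is one reason reducing gradient dimensionality matters.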

Abstract

Training with larger mini-batches improves the convergence rate and can yield superior performance. However, training with large mini-batches becomes prohibitive for Large Language Models (LLMs) due to the large GPU memory requirement. To address this problem, an effective approach is finding small mini-batch coresets that closely match the gradient of larger mini-batches. However, this approach becomes infeasible and ineffective for LLMs due to the highly imbalanced mixture of sources in language data, the use of the Adam optimizer, and the very large gradient dimensionality of LLMs. In this work, we address the above challenges by proposing Coresets for Training LLMs (CoLM). First, we show that mini-batch coresets found by gradient matching do not, with high probability, contain representative examples of the small sources, and thus including all examples of the small sources in the mini-batch coresets is crucial for optimal performance. Second, we normalize the gradients by their historical exponential moving average to find mini-batch coresets for training with Adam. Finally, we leverage zeroth-order methods to find a smooth gradient of the last V-projection matrix and sparsify it to keep the dimensions with the largest normalized gradient magnitude. We apply CoLM to fine-tuning Phi-2, Phi-3, Zephyr, and Llama-3 models with LoRA on the MathInstruct and SuperGLUE benchmarks. Remarkably, CoLM reduces the memory requirement of fine-tuning by 2x and even outperforms training with 4x larger mini-batches. Moreover, CoLM seamlessly integrates with existing memory-efficient training methods like LoRA, further reducing the memory requirements of training LLMs. Our code is available at https://github.com/BigML-CS-UCLA/CoLM.
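The gradient-feature steps (zeroth-order estimation, Adam-style normalization, sparsification) can be sketched as follows. All names are hypothetical and the estimator is an assumption. The sketch uses a standard two-point SPSA (simultaneous perturbation) estimate averaged over several random Gaussian directions. Normalization uses an Adam-style second-moment EMA `v_ema`. Sparsification keeps a shared set of `top_s` dimensions chosen by mean absolute normalized gradient across the batch. In CoLM the parameters would be the last V-projection matrix (flattened into `theta`). The exact estimator and sparsification rule used in the paper may differ.

```python
import numpy as np


def spsa_per_example_grads(loss_fn, theta, num_dirs=8, eps=1e-3, seed=0):
    """Two-point SPSA estimate of each example's gradient w.r.t. `theta`.

    `loss_fn(theta)` must return per-example losses of shape (n,).
    The estimate is averaged over `num_dirs` (>= 1) random Gaussian directions.
    """
    rng = np.random.default_rng(seed)
    grads = None
    for _ in range(num_dirs):
        z = rng.standard_normal(theta.size)
        # finite-difference slope of every example's loss along direction z
        slope = (loss_fn(theta + eps * z) - loss_fn(theta - eps * z)) / (2 * eps)
        step = np.outer(slope, z)  # (n, d): slope_i * z for each example i
        grads = step if grads is None else grads + step
    return grads / num_dirs


def update_second_moment(v_ema, batch_grad, beta2=0.999):
    """Adam-style exponential moving average of the squared mini-batch gradient."""
    return beta2 * v_ema + (1.0 - beta2) * batch_grad ** 2


def normalize_and_sparsify(grads, v_ema, top_s, adam_eps=1e-8):
    """Normalize per-example gradients by sqrt(v_ema) and keep top_s dims.

    The same dimensions are kept for every example so that pairwise
    distances between the resulting features remain comparable.
    """
    normed = grads / (np.sqrt(v_ema) + adam_eps)
    scores = np.abs(normed).mean(axis=0)
    dims = np.argsort(-scores, kind="stable")[:top_s]
    return normed[:, dims], dims
```

The SPSA estimate needs only forward passes, which is the appeal of zeroth-order methods for very large parameter blocks. Normalizing by the second-moment EMA makes the features reflect the update direction Adam would actually take rather than the raw gradient. For example, with `v_ema = [4.0, 0.25]` a raw gradient of `[1.0, 1.0]` becomes `[0.5, 2.0]`, so the second dimension is the one kept.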

Paper Structure (21 sections, 10 theorems, 19 equations, 6 figures, 9 tables, 1 algorithm)

Key Result

Theorem 4.1

Let the examples in $V_q$ be partitioned into $m$ parts. A sample size of $|V_q| \geq \frac{2km \log(km / \delta)}{\beta g(\alpha)}$, where $\alpha \leq \alpha^\star$, suffices to (1) have at least $km \log(km / \delta)$ elements in the $\alpha$-neighborhood of each $i \in A_q^t$ and (2) have each
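To get a feel for the sample-size condition in Theorem 4.1, its right-hand side can be evaluated numerically. This is a sketch assuming the natural logarithm and treating $\beta$ and $g(\alpha)$ as given inputs, since they are defined in the full paper rather than in this excerpt. The function names are hypothetical.

```python
import math


def theorem41_min_examples(k: int, m: int, delta: float, beta: float, g_alpha: float) -> float:
    """Right-hand side 2km log(km/delta) / (beta * g(alpha)) of Theorem 4.1."""
    km = k * m
    return 2 * km * math.log(km / delta) / (beta * g_alpha)


def theorem41_neighborhood_count(k: int, m: int, delta: float) -> float:
    """Condition (1): km log(km/delta) elements in each alpha-neighborhood."""
    return k * m * math.log(k * m / delta)


# illustrative values only; beta and g(alpha) come from the full paper's definitions
n_required = math.ceil(theorem41_min_examples(k=3, m=2, delta=0.05, beta=0.5, g_alpha=0.1))
```

With these illustrative values ($k=3$, $m=2$, $\delta=0.05$, $\beta=0.5$, $g(\alpha)=0.1$), the bound is $240 \ln 120 \approx 1149$ examples. The requirement grows roughly as $km\log(km/\delta)$ and shrinks as $\beta g(\alpha)$ increases, i.e., as the neighborhoods become denser.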

Figures (6)

  • Figure 1: A toy imbalanced dataset. (Left) Full data $V$ with two big sources (blue, green) and one small source (purple). The $k=3$ medoids of the data are shown in red. (Middle & right) Two random samples of the data, with their corresponding $k=3$ medoids. The $\alpha^\star$-neighborhoods of the big sources are dense, and thus the medoids of random samples contain central examples of the big sources. However, the medoids of a random sample do not necessarily contain central examples of the small source.
  • Figure 2: (a) CoLM with bs = 64 (from 128) outperforms fine-tuning different models with bs = 64 and bs = 128 by a large margin; (b) CoLM improves the performance of training with different batch sizes. The size of each circle is proportional to the training time of the corresponding method. (c) CoLM reduces memory consumption, with reduction increasing as the batch size grows.
  • Figure 3: Fine-tuning Phi-2 on MathInstruct. (a) Wall-clock time (including the time for CoLM's selection), memory consumption, and performance of fine-tuning. CoLM outperforms normal fine-tuning for 1K iterations with bs = 128 (256), while being 1.3x (2.7x) faster and consuming 20% (45%) less memory, respectively; (b) CoLM has a smaller variance than random mini-batches of the same size; (c) CoLM converges much faster than normal fine-tuning (FT).
  • Figure 4: (a) Data distribution of different data sources in MathInstruct. (b) The average completion length of examples selected by CoLM vs. random examples and longest examples in random batches.
  • Figure 5: Fine-tuning Phi-2 on MathInstruct. (a) CoLM yields the smallest loss throughout the whole training process; (b) CoLM reaches the optimal performance in less training time.
  • ...and 1 more figure

Theorems & Definitions (15)

  • Theorem 4.1
  • Theorem 4.2
  • Theorem 4.3: Variance reduction
  • Lemma A.4
  • proof
  • Lemma A.5: Sampling with replacement
  • proof
  • Lemma A.6: Sampling without replacement
  • proof
  • Corollary A.7: Bound for local optimal solution
  • ...and 5 more