Efficient Ensemble for Fine-tuning Language Models on Multiple Datasets

Dongyue Li, Ziniu Zhang, Lu Wang, Hongyang R. Zhang

TL;DR

This work tackles robust parameter-efficient fine-tuning of large language models across multiple datasets by introducing an ensemble of small adapters, each trained for a group of tasks. Task affinities among datasets are estimated with a first-order Taylor expansion around the base model, using gradients evaluated at a fixed initialization, which produces a task affinity matrix $T$ that is clustered into $m$ groups. One adapter is trained per group and optionally refined with a few gradient-boosting steps; the adapter outputs are then combined with learned weights, yielding $M = m + b$ adapters in total ($m$ from the groups plus $b$ from boosting). Empirically, the method achieves up to 10% higher average accuracy than QLoRA on SuperGLUE tasks with modest computational overhead (about 9 GB extra memory and 9% more FLOPs) and scales to federated settings with hundreds of datasets, while maintaining strong generalization as shown by Hessian-based sharpness analyses. Overall, the approach offers a practical, scalable path to multi-task fine-tuning that preserves efficiency while boosting performance.

Abstract

This paper develops an ensemble method for fine-tuning a language model to multiple datasets. Existing methods, such as quantized LoRA (QLoRA), are efficient when adapting to a single dataset. When training on multiple datasets of different tasks, a common setup in practice, it remains unclear how to design an efficient adaptation for fine-tuning language models. We propose to use an ensemble of multiple smaller adapters instead of a single adapter per task. We design an efficient algorithm that partitions $n$ datasets into $m$ groups, where $m$ is typically much smaller than $n$ in practice, and train one adapter for each group before taking a weighted combination to form the ensemble. The algorithm leverages a first-order approximation property of low-rank adaptation to quickly obtain the fine-tuning performances of dataset combinations since methods like LoRA stay close to the base model. Hence, we use the gradients of the base model to estimate its behavior during fine-tuning. Empirically, this approximation holds with less than $1\%$ error on models with up to $34$ billion parameters, leading to an estimation of true fine-tuning performances under $5\%$ error while speeding up computation compared to base fine-tuning by $105$ times. When applied to fine-tune Llama and GPT models on ten text classification tasks, our approach provides up to $10\%$ higher average test accuracy over QLoRA, with only $9\%$ more FLOPs. On a Llama model with $34$ billion parameters, an ensemble of QLoRA increases test accuracy by $3\%$ compared to QLoRA, with only $8\%$ more FLOPs.
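The first-order approximation property mentioned in the abstract can be illustrated on a toy problem. The sketch below is not the paper's code: the `tanh` model, the names `model_output`, `output_gradient`, and `first_order_estimate`, and the step sizes are all assumptions chosen for illustration. It checks only the general fact that a linearization around $\theta^{\star}$ is accurate when the fine-tuned weights stay close to $\theta^{\star}$.

```python
import numpy as np

def model_output(theta, x):
    """Toy nonlinear model (an assumption): one scalar output per example."""
    return np.tanh(x @ theta)

def output_gradient(theta, x):
    """Gradient of model_output with respect to theta, one row per example."""
    return (1.0 - np.tanh(x @ theta) ** 2)[:, None] * x

def first_order_estimate(theta_star, delta, x):
    """Linearization: f(theta* + delta) ~ f(theta*) + grad f(theta*) . delta"""
    return model_output(theta_star, x) + output_gradient(theta_star, x) @ delta

def relative_error(theta_star, delta, x):
    """Relative gap between the exact output and its first-order estimate."""
    exact = model_output(theta_star + delta, x)
    approx = first_order_estimate(theta_star, delta, x)
    return np.linalg.norm(exact - approx) / np.linalg.norm(exact)

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8))
theta_star = 0.1 * rng.normal(size=8)      # stands in for the base weights
direction = rng.normal(size=8)
direction /= np.linalg.norm(direction)     # unit-norm update direction

small_error = relative_error(theta_star, 0.01 * direction, x)  # stays close
large_error = relative_error(theta_star, 1.0 * direction, x)   # moves far
print(f"small step: {small_error:.2e}, large step: {large_error:.2e}")
```

In this toy setting the error for the small update is far below that of the large one, which is the regime the abstract attributes to LoRA-style adapters.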

Paper Structure

This paper contains 27 sections, 14 equations, 8 figures, 7 tables, 2 algorithms.

Figures (8)

  • Figure 1: Left: We propose an ensemble method for fine-tuning language models on multiple datasets. Given $n$ datasets and a base adaptation method such as LoRA, we design an ensemble of adapters that applies weighted averaging to their outputs, with minimal computation and memory overhead to the base method. Middle: We partition $n$ datasets into $m$ groups based on the task affinity scores. Our method first estimates fine-tuning performances on multiple dataset combinations, $S_1, S_2, \dots, S_k$, by evaluating the gradients of the adapter weights at $\theta^{\star}$. For each subset $S_i$, we estimate the fine-tuned adapter weights $\hat{\theta}_{S_i}$ by solving a regression problem, using the projected gradients within $S_i$ as features. This leads to an $n$ by $n$ task affinity matrix $T$, where $T_{i,j}$ is the affinity score computed from the estimates (see equation \ref{eq_affinity_score}). We then partition the $n$ datasets into $m$ groups with a clustering algorithm applied to $T$. In practice, $m$ is usually much smaller than $n$. Right: We design an adapter ensemble by fine-tuning one adapter per group and further refining it with a few gradient-boosting steps. This overall procedure incurs little computational overhead, as we will describe in Table \ref{table_compare}.
  • Figure 2: We report the approximation error for Llama and GPT-J models with up to 34 billion parameters for LoRA, QLoRA, adapter, and QAdapter. We report the average and the standard deviation based on the results from $50$ randomly sampled task subsets of size $3$.
  • Figure 3: We compare error rate (one minus accuracy), computation cost, and memory usage across our approach and baselines when fine-tuning Llama-3-8B on ten NLP tasks. MTL-FT refers to first fine-tuning a shared LoRA on all the datasets, and then fine-tuning the low-rank adapter on each dataset, while Full FT refers to full fine-tuning of the entire model. Our approach boosts the test accuracy of QLoRA by $10\%$ on average, only incurring $8\%$ additional computation and 9 GB more memory. It performs on par with the best baseline with $45\%$ fewer FLOPs.
  • Figure 4: Illustrating the empirical generalization errors and sharpness measures with respect to QLoRA weights. \ref{fig_vary_rank} Smaller adapters with a rank of $16$ achieve the lowest generalization errors. Additionally, the Hessian trace values correlate with generalization errors, suggesting that smaller adapters tend to converge to flatter minima. \ref{fig_ensemble_size} An ensemble of $k$ adapters leads to lower generalization errors and Hessian traces. Here, we fix the sum of the dimensions of $k$ adapters to be equal to $256$. \ref{fig_vary_quantize} QLoRA, which is trained on a quantized base model, yields lower generalization errors and Hessian trace values compared to LoRA.
  • Figure 5: This figure compares the error rate (one minus average test accuracy), computation cost, and GPU memory across our approach and baselines, for fine-tuning Llama-3-8B on ten NLP tasks with QAdapter.
  • ...and 3 more figures
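
The grouping and weighted-averaging steps described in the Figure 1 caption can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the excerpt does not say which clustering algorithm is applied to $T$, so a greedy average-linkage merge stands in for it, and the ensemble weights are fixed by hand rather than learned. The names `group_tasks` and `ensemble_predict` and the example matrix are hypothetical.

```python
import numpy as np

def group_tasks(affinity, num_groups):
    """Partition n tasks into num_groups groups.

    Stand-in clustering (an assumption): repeatedly merge the two groups
    with the highest average cross-affinity until num_groups remain.
    """
    sym = (affinity + affinity.T) / 2.0          # symmetrize T
    groups = [[i] for i in range(sym.shape[0])]  # start with singletons
    while len(groups) > num_groups:
        best_pair, best_score = None, -np.inf
        for a in range(len(groups)):
            for b in range(a + 1, len(groups)):
                score = sym[np.ix_(groups[a], groups[b])].mean()
                if score > best_score:
                    best_pair, best_score = (a, b), score
        a, b = best_pair
        groups[a] = groups[a] + groups[b]
        del groups[b]
    return [sorted(g) for g in groups]

def ensemble_predict(adapter_logits, weights):
    """Weighted average of per-adapter outputs.

    adapter_logits has shape (num_adapters, batch, classes); the weights
    are normalized to sum to one before averaging.
    """
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    return np.tensordot(weights, np.asarray(adapter_logits, dtype=float), axes=1)

# Hypothetical 4-task affinity matrix with two visible blocks.
T = np.array([
    [1.0, 0.9, 0.1, 0.2],
    [0.8, 1.0, 0.2, 0.1],
    [0.1, 0.2, 1.0, 0.7],
    [0.2, 0.1, 0.9, 1.0],
])
groups = group_tasks(T, num_groups=2)

# Two adapters, one example, two classes; weights fixed by hand here.
logits = [[[2.0, 0.0]], [[0.0, 2.0]]]
combined = ensemble_predict(logits, weights=[3.0, 1.0])
print(groups, combined)
```

In this sketch one adapter would then be trained per entry of `groups`; the boosting refinement mentioned in the caption is omitted.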