Efficient Ensemble for Fine-tuning Language Models on Multiple Datasets
Dongyue Li, Ziniu Zhang, Lu Wang, Hongyang R. Zhang
TL;DR
This work tackles robust parameter-efficient fine-tuning of large language models across multiple datasets by introducing an ensemble of small adapters, each trained for a group of tasks. Task affinities among datasets are estimated with a first-order Taylor expansion around the base model, using gradients computed at a fixed initialization. This produces a task affinity matrix $T$, which is clustered into $m$ groups. One adapter is trained per group and optionally boosted with a few gradient steps, and the adapter outputs are then combined with learned weights, yielding $M = m + b$ adapters in total. Empirically, the method achieves up to 10% higher average accuracy than QLoRA on SuperGLUE tasks with modest computational overhead (about 9 GB of extra memory and 9% more FLOPs) and scales to federated settings with hundreds of datasets, while maintaining strong generalization as shown by Hessian-based sharpness analyses. Overall, the approach offers a practical, scalable path to multi-task fine-tuning that improves accuracy at a small additional cost.
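As a rough illustration of the first-order Taylor expansion mentioned above, the approximation can be written as follows. The symbols are assumptions introduced here for illustration and are not taken from the summary: $\theta_0$ denotes the base model weights, $\Delta$ the low-rank adapter update, and $f_{\theta}(x)$ the model output on input $x$.

$$
f_{\theta_0 + \Delta}(x) \;\approx\; f_{\theta_0}(x) + \big\langle \nabla_{\theta} f_{\theta_0}(x),\, \Delta \big\rangle
$$

Under this sketch, the gradients $\nabla_{\theta} f_{\theta_0}(x)$ are computed once at the base model and reused for every candidate combination of datasets, which is plausibly where the reported speedup over full fine-tuning comes from.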
Abstract
This paper develops an ensemble method for fine-tuning a language model on multiple datasets. Existing methods, such as quantized LoRA (QLoRA), are efficient when adapting to a single dataset. When training on multiple datasets of different tasks, a common setup in practice, it remains unclear how to design an efficient adaptation for fine-tuning language models. We propose to use an ensemble of multiple smaller adapters instead of a single adapter per task. We design an efficient algorithm that partitions $n$ datasets into $m$ groups, where $m$ is typically much smaller than $n$ in practice, and trains one adapter for each group before taking a weighted combination to form the ensemble. The algorithm leverages a first-order approximation property of low-rank adaptation to quickly estimate the fine-tuning performance of dataset combinations: since methods like LoRA stay close to the base model, we can use the gradients of the base model to estimate its behavior during fine-tuning. Empirically, this approximation holds with less than $1\%$ error on models with up to $34$ billion parameters, leading to estimates of true fine-tuning performance with under $5\%$ error while speeding up computation by $105$ times compared to base fine-tuning. When applied to fine-tune Llama and GPT models on ten text classification tasks, our approach provides up to $10\%$ higher average test accuracy than QLoRA, with only $9\%$ more FLOPs. On a Llama model with $34$ billion parameters, an ensemble of QLoRA adapters increases test accuracy by $3\%$ compared to QLoRA, with only $8\%$ more FLOPs.
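The grouping and weighted-combination steps described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the paper's algorithm: the abstract does not specify the clustering procedure, so a simple greedy agglomerative merge stands in for it, and the function names (`average_affinity`, `group_tasks`, `ensemble_output`), the toy affinity matrix, and the ensemble weights are all hypothetical.

```python
from itertools import combinations


def average_affinity(T, group_a, group_b):
    """Mean symmetrized affinity between two groups of task indices."""
    total = sum(T[i][j] + T[j][i] for i in group_a for j in group_b)
    return total / (2 * len(group_a) * len(group_b))


def group_tasks(T, m):
    """Partition n tasks into m groups from an n x n affinity matrix T.

    Assumption: greedy agglomerative merging is used here as a stand-in;
    the source does not state which clustering procedure is applied.
    """
    groups = [[i] for i in range(len(T))]
    while len(groups) > m:
        # Merge the pair of groups with the highest average affinity.
        a, b = max(
            combinations(range(len(groups)), 2),
            key=lambda pair: average_affinity(T, groups[pair[0]], groups[pair[1]]),
        )
        groups[a] = sorted(groups[a] + groups[b])
        del groups[b]  # b > a, so index a is unaffected
    return sorted(groups)


def ensemble_output(adapter_outputs, weights):
    """Weighted combination of per-adapter output vectors (e.g. class scores)."""
    dim = len(adapter_outputs[0])
    return [
        sum(w * out[k] for w, out in zip(weights, adapter_outputs))
        for k in range(dim)
    ]


# Toy affinity matrix for n = 4 tasks (made-up values):
# tasks 0 and 1 benefit each other, as do tasks 2 and 3.
T = [
    [1.0, 0.9, 0.1, 0.2],
    [0.8, 1.0, 0.2, 0.1],
    [0.1, 0.2, 1.0, 0.7],
    [0.2, 0.1, 0.9, 1.0],
]
groups = group_tasks(T, m=2)  # -> [[0, 1], [2, 3]]

# One adapter per group would be trained here; their outputs are then combined.
scores = ensemble_output([[0.2, 0.8], [0.6, 0.4]], weights=[0.5, 0.5])
```

In this sketch the ensemble weights are fixed by hand; in the method described above they are part of the weighted combination that forms the ensemble, and the affinity matrix would come from the gradient-based estimates rather than hard-coded values.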
