Dynamic Gradient Alignment for Online Data Mixing
Simin Fan, David Grangier, Pierre Ablin
TL;DR
Dynamic Gradient Alignment (DGA) addresses the problem of choosing a data mixture for pre-training a large language model when only a small specialized dataset is available. It frames the problem as a bilevel optimization that ties the inner training loss on the generic domains to the outer objective on the specialized set. Domain weights are updated online on the simplex using gradient-based rules, with an optional exponential moving average (EMA) for stability. A scalable distribution-reweighting variant extends DGA to thousands of fine-grained domains by combining basis distributions, which enables effective data mixing without prohibitive gradient computation. In token-constrained and large-domain scenarios, DGA outperforms static importance sampling, and the EMA makes it more robust. It yields notable gains in language modeling and in task-focused performance, although language-modeling gains do not always translate to reasoning tasks. The approach offers a tractable way to specialize large models when data is scarce.
Abstract
The composition of training data mixtures is critical for effectively training large language models (LLMs), as it directly impacts their performance on downstream tasks. Our goal is to identify an optimal data mixture to specialize an LLM for a specific task with access to only a few examples. Traditional approaches to this problem include ad-hoc reweighting methods, importance sampling, and gradient alignment techniques. This paper focuses on gradient alignment and introduces Dynamic Gradient Alignment (DGA), a scalable online gradient alignment algorithm. DGA dynamically estimates the pre-training data mixture on which the model's gradients align as well as possible with those of the model on the specific task. DGA is the first gradient alignment approach that incurs minimal overhead compared to standard pre-training and outputs a competitive model, eliminating the need to retrain the model. Experimentally, we demonstrate significant improvements over importance sampling in two key scenarios: (i) when the pre-training set is small and importance sampling overfits due to limited data; and (ii) when there is insufficient specialized data, trapping importance sampling on narrow pockets of data. Our findings underscore the effectiveness of gradient alignment methods in optimizing training data mixtures, particularly in data-constrained environments, and offer a practical solution for enhancing LLM performance on specific tasks with limited data availability.
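
As a rough illustration of the idea described above, the sketch below performs one online update of the domain weights from gradient alignment scores, followed by EMA smoothing. It is a minimal sketch under stated assumptions, not the authors' implementation: the exponentiated-gradient form of the simplex update, the function names (`alignment_scores`, `simplex_step`, `ema_smooth`), the step size, the EMA coefficient, and the toy gradient vectors are all hypothetical choices made for illustration.

```python
import numpy as np


def alignment_scores(domain_grads, spec_grad):
    """Dot product between each domain's gradient and the specialized-set gradient.

    domain_grads: array of shape (num_domains, num_params)
    spec_grad:    array of shape (num_params,)
    """
    return domain_grads @ spec_grad


def simplex_step(weights, scores, step_size=0.5):
    """One exponentiated-gradient step (an assumed update rule).

    Domains whose gradients align with the specialized gradient gain weight.
    Weights are assumed strictly positive; the result stays on the simplex.
    """
    logits = np.log(weights) + step_size * scores
    logits -= logits.max()  # shift for numerical stability before exponentiating
    new_weights = np.exp(logits)
    return new_weights / new_weights.sum()


def ema_smooth(prev_weights, new_weights, beta=0.9):
    """Exponential moving average of the weights (beta is an assumed value)."""
    mixed = beta * prev_weights + (1.0 - beta) * new_weights
    return mixed / mixed.sum()  # renormalize to guard against rounding drift


# Toy example: three generic domains, a two-parameter model.
domain_grads = np.array([
    [1.0, 0.0],   # aligned with the specialized gradient
    [0.0, 1.0],   # orthogonal to it
    [-1.0, 0.0],  # opposed to it
])
spec_grad = np.array([1.0, 0.0])

weights = np.full(3, 1.0 / 3.0)  # start from the uniform mixture
scores = alignment_scores(domain_grads, spec_grad)
stepped = simplex_step(weights, scores)
smoothed = ema_smooth(weights, stepped)
print(smoothed)
```

In this toy setting the aligned domain gains weight and the opposed domain loses weight, while the EMA keeps the smoothed weights between the previous mixture and the raw update. In an actual training run the gradients would come from the model at its current parameters rather than from fixed vectors.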
