
Dynamic Gradient Alignment for Online Data Mixing

Simin Fan, David Grangier, Pierre Ablin

TL;DR

Dynamic Gradient Alignment (DGA) addresses the challenge of crafting an optimal data mixture for pretraining large language models when only a small specialized dataset is available. It performs online gradient alignment: domain weights on a simplex are updated with gradient-based rules, optionally stabilized by an exponential moving average (EMA), within a bilevel framework that ties the inner generic training loss to the outer specialized objective. A scalable distribution-reweighting variant extends DGA to thousands of fine-grained domains by reweighting a small set of basis distributions, which avoids computing a gradient for every domain. In both the token-constrained and the many-domain settings, DGA outperforms static importance sampling, and the EMA makes it more robust. DGA yields notable gains in language modeling and improves task-focused performance, although lower language-modeling loss does not always translate into better accuracy on reasoning tasks. The approach offers a tractable path to specializing large models under data scarcity in real-world settings.
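
The TL;DR describes the weight update only at a high level. As a hedged illustration, the sketch below assumes an exponentiated-gradient (mirror-descent) step on the simplex: each domain's weight is multiplied by exp(step_size × ⟨∇ℓ_i, ∇L_spe⟩), renormalized, and then smoothed with an EMA. The helper names, the `step_size` parameter, and the convention that `beta` weights the newest estimate are our assumptions, not the authors' code, and the gradients are toy vectors rather than model gradients.

```python
import math
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two flattened gradient vectors."""
    return sum(a * b for a, b in zip(u, v))


def dga_weight_step(
    weights: Sequence[float],
    domain_grads: Sequence[Sequence[float]],
    specific_grad: Sequence[float],
    step_size: float,
) -> List[float]:
    """One exponentiated-gradient step on the simplex (assumed update rule).

    Domains whose gradient aligns with the specific-set gradient
    (positive inner product) gain weight; the result is renormalized
    so the weights stay on the probability simplex. Weights must be > 0.
    """
    scores = [dot(g, specific_grad) for g in domain_grads]
    logits = [math.log(w) + step_size * s for w, s in zip(weights, scores)]
    shift = max(logits)  # subtract the max for numerical stability
    unnorm = [math.exp(z - shift) for z in logits]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def ema(previous: Sequence[float], new: Sequence[float], beta: float) -> List[float]:
    """EMA smoothing (assumed: beta weights the new estimate).

    A convex combination of two simplex points stays on the simplex.
    """
    return [(1.0 - beta) * p + beta * n for p, n in zip(previous, new)]


# Toy example: two domains, only the first aligns with the specific gradient.
w = [0.5, 0.5]
raw = dga_weight_step(w, [[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0], step_size=1.0)
smoothed = ema(w, raw, beta=0.1)
print([round(x, 3) for x in raw])       # [0.731, 0.269]
print([round(x, 3) for x in smoothed])  # [0.523, 0.477]
```

Multiplicative updates keep every weight positive and the vector normalized, which makes mirror descent a natural fit for optimizing over a simplex; the EMA then damps step-to-step oscillations in the weights.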

Abstract

The composition of training data mixtures is critical for effectively training large language models (LLMs), as it directly impacts their performance on downstream tasks. Our goal is to identify an optimal data mixture to specialize an LLM for a specific task with access to only a few examples. Traditional approaches to this problem include ad-hoc reweighting methods, importance sampling, and gradient alignment techniques. This paper focuses on gradient alignment and introduces Dynamic Gradient Alignment (DGA), a scalable online gradient alignment algorithm. DGA dynamically estimates the pre-training data mixture on which the model's gradients align as closely as possible with its gradients on the specific task. DGA is the first gradient alignment approach that incurs minimal overhead compared to standard pre-training and outputs a competitive model, eliminating the need to retrain the model. Experimentally, we demonstrate significant improvements over importance sampling in two key scenarios: (i) when the pre-training set is small and importance sampling overfits due to limited data; and (ii) when there is insufficient specialized data, which traps importance sampling in narrow pockets of data. Our findings underscore the effectiveness of gradient alignment methods in optimizing training data mixtures, particularly in data-constrained environments, and offer a practical solution for enhancing LLM performance on specific tasks with limited data availability.
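
To make the importance-sampling baseline concrete, here is a minimal sketch of one common instantiation, assumed here rather than taken from this summary: each specific example is assigned to a generic domain (for instance, its nearest embedding cluster), and the static mixture weight of a domain is the fraction of specific examples assigned to it. The function name and the toy assignments are hypothetical.

```python
from collections import Counter
from typing import List, Sequence


def importance_sampling_weights(
    specific_assignments: Sequence[int], num_domains: int
) -> List[float]:
    """Static mixture (assumed baseline): weight of domain i is the
    fraction of specific examples assigned to domain i."""
    counts = Counter(specific_assignments)
    n = len(specific_assignments)
    return [counts[i] / n for i in range(num_domains)]


# Hypothetical assignments of 8 specific examples to 4 generic domains.
assignments = [0, 0, 0, 1, 1, 0, 3, 0]
weights = importance_sampling_weights(assignments, num_domains=4)
print(weights)  # [0.625, 0.25, 0.0, 0.125]
```

With so few specific examples, domain 2 receives zero weight and domain 0 dominates. This illustrates scenario (ii): the fixed histogram can confine training to a narrow pocket of the generic data, and if that pocket holds few tokens, repeated sampling from it can also lead to the overfitting of scenario (i).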

Paper Structure

This paper contains 22 sections, 1 theorem, 9 equations, 12 figures, 4 tables, and 2 algorithms.

Key Result

Theorem 1

Assume that there exists $\tilde{\bm{\alpha}}$ such that $D_{\mathrm{spe}} = \mathrm{mix}(\tilde{\bm{\alpha}})$. Then $\tilde{\bm{\alpha}}$ is a solution to the bilevel problem of Eq. (bilevel-form).
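
The bilevel equation that the theorem cites is not reproduced in this summary. Assuming the standard formulation suggested by the TL;DR (an inner generic training problem weighted by $\bm{\alpha}$ and an outer specialized loss), the following is a short derivation of why the theorem holds. The symbols $\Delta_k$, $\ell_i$, $\ell(\theta; x)$, and $L_{\mathrm{spe}}$ are our notation, not necessarily the paper's.

```latex
% Assumed notation: k generic domains D_1..D_k, weights alpha in the simplex Delta_k,
% per-domain losses ell_i(theta) = E_{x ~ D_i} ell(theta; x), specific loss L_spe(theta).
\min_{\bm{\alpha} \in \Delta_k} \; L_{\mathrm{spe}}\big(\theta^\star(\bm{\alpha})\big)
\quad \text{s.t.} \quad
\theta^\star(\bm{\alpha}) \in \arg\min_{\theta} \sum_{i=1}^{k} \alpha_i \, \ell_i(\theta)

% If D_spe = mix(alpha~) = sum_i alpha~_i D_i, linearity of expectation gives, for all theta,
\sum_{i=1}^{k} \tilde{\alpha}_i \, \ell_i(\theta)
  = \mathbb{E}_{x \sim \mathrm{mix}(\tilde{\bm{\alpha}})} \, \ell(\theta; x)
  = L_{\mathrm{spe}}(\theta)

% Hence theta*(alpha~) minimizes L_spe, and for every alpha in Delta_k:
L_{\mathrm{spe}}\big(\theta^\star(\tilde{\bm{\alpha}})\big)
  = \min_{\theta} L_{\mathrm{spe}}(\theta)
  \le L_{\mathrm{spe}}\big(\theta^\star(\bm{\alpha})\big)
```

In words: when the specific distribution is itself a mixture of the generic domains, training on that exact mixture is the same as training on the specific data, so no other weighting can reach a lower specific loss.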

Figures (12)

  • Figure 1: Comparison of data reweighting methods with free_law as the specific set in a low generic-data regime. When generic tokens are scarce, importance sampling quickly overfits, while DGA explores the training distributions and avoids overfitting. The EMA is important for stabilizing DGA in the low-data regime; when there is no token limit, adding an EMA ($\beta=0.1$) does not hurt performance.
  • Figure 2: Distribution reweighting experiment.
  • Figure 3: Top row: specific loss over time. Bottom two rows: evolution of the domain (or distribution) weights assigned by DGA in the corresponding runs; each line represents a distinct domain. Left: weights from the limited generic token experiment. Middle and right: weights from the distribution reweighting experiment. The thick black line highlights the weight that DGA dynamically assigns to the MMLU importance sampling distribution, which serves as the fixed training distribution for the importance sampling runs.
  • Figure 4: Impact of the generic-set granularity on the distribution reweighting experiment. We report the specific loss obtained after training for different granularities of the base clustering.
  • Figure 5: Results on all domains for the low-data experiment. The specific domain is free_law.
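
Figures 2 to 4 concern the distribution-reweighting variant, which the TL;DR describes as combining basis distributions so that DGA scales to thousands of fine-grained domains. The sketch below shows only the sampling side under that reading: DGA would learn a small weight vector over k fixed basis distributions, and the effective distribution over N fine-grained domains is their weighted average. The specific bases (a uniform distribution and an importance-sampling-style histogram, echoing the MMLU distribution highlighted in Figure 3) and all names are hypothetical.

```python
import random
from typing import List, Sequence


def mixed_domain_distribution(
    alpha: Sequence[float], bases: Sequence[Sequence[float]]
) -> List[float]:
    """p = sum_j alpha_j * P_j: a distribution over N fine-grained domains
    built from k fixed basis distributions, so only k weights are learned."""
    n_domains = len(bases[0])
    return [sum(a * basis[i] for a, basis in zip(alpha, bases)) for i in range(n_domains)]


def sample_domains(p: Sequence[float], num_samples: int, seed: int = 0) -> List[int]:
    """Draw fine-grained domain indices (e.g. to fill a training batch)."""
    rng = random.Random(seed)
    return rng.choices(range(len(p)), weights=p, k=num_samples)


# Hypothetical bases over 4 fine-grained domains (k = 2 learned weights).
uniform = [0.25, 0.25, 0.25, 0.25]
targeted = [0.7, 0.3, 0.0, 0.0]  # e.g. an importance-sampling histogram
p = mixed_domain_distribution([0.4, 0.6], [uniform, targeted])
print([round(x, 2) for x in p])  # [0.52, 0.28, 0.1, 0.1]
batch = sample_domains(p, num_samples=8)
```

Because the learned vector has only k entries, a gradient-alignment update in this setting needs k alignment scores per step instead of one per fine-grained domain, while the uniform basis keeps every domain reachable.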

Theorems & Definitions (2)

  • Theorem 1
  • Proof