
Optimizing Pre-Training Data Mixtures with Mixtures of Data Expert Models

Lior Belenki, Alekh Agarwal, Tianze Shi, Kristina Toutanova

TL;DR

This work tackles the challenge of optimizing pre-training data mixtures for large language models by introducing Mixture of Data Experts (MDE), an ensemble-based proxy that estimates generalization losses for candidate mixtures from the per-token probabilities of domain-specific expert models. Using MDE outputs as features in regression models (Linear, GBM, MTGP), the authors achieve substantial improvements in ranking and loss-prediction accuracy, enabling more effective data-mixture optimization at reduced computational cost. Their theoretical analysis shows that the loss-minimizing predictor for a mixture can be expressed as a context-dependent weighted combination of the domain-specific optimal predictors, which supports the MDE approach. Empirically, MDE-assisted optimization on SlimPajama data yields better few-shot downstream performance across generation and ranking tasks, and MDE remains competitive as a standalone estimator, offering a practical, sample-efficient path to better data mixtures for the 70M to 1B parameter models studied.
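The per-token MDE approximation can be sketched in a few lines. This is a minimal illustration under assumptions, not the authors' code: it assumes each expert's probability for every observed validation token has already been computed and stored in an array, and the function name `mde_feature` is hypothetical. The feature is reported as mean negative log-probability so that it sits on the same scale as a cross-entropy loss.

```python
import numpy as np


def mde_feature(expert_token_probs, lam):
    """Hypothetical helper: MDE loss estimate for one validation domain.

    expert_token_probs: array of shape (k, T); entry [i, t] is the probability
        that data expert E_i assigns to the observed t-th validation token.
    lam: length-k mixture weights on the (k-1)-simplex.
    """
    probs = np.asarray(expert_token_probs, dtype=float)
    lam = np.asarray(lam, dtype=float)
    if np.any(lam < 0) or not np.isclose(lam.sum(), 1.0):
        raise ValueError("lam must be non-negative and sum to 1")
    # Per-token p_MDE: lambda-weighted average of the experts' probabilities.
    p_mde = lam @ probs  # shape (T,)
    # Mean negative log-probability over tokens (a cross-entropy estimate).
    return float(-np.mean(np.log(p_mde)))
```

For example, `mde_feature(np.array([[0.5, 0.25], [0.125, 0.5]]), [0.5, 0.5])` averages the two experts' probabilities token by token before taking the log. Because averaging happens in probability space rather than loss space, the estimate is never larger than the $\lambda$-weighted average of the individual experts' losses (Jensen's inequality).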

Abstract

We propose a method to optimize language model pre-training data mixtures through efficient approximation of the cross-entropy loss corresponding to each candidate mixture via a Mixture of Data Experts (MDE). We use this approximation as a source of additional features in a regression model, trained from observations of model loss for a small number of mixtures. Experiments with Transformer decoder-only language models in the range of 70M to 1B parameters on the SlimPajama dataset show that our method achieves significantly better performance than approaches that train regression models using only the mixture rates as input features. Combining this improved optimization method with an objective that takes into account cross-entropy on end task data leads to superior performance on few-shot downstream evaluations. We also provide theoretical insights on why aggregation of data expert predictions can provide good approximations to model losses for data mixtures.
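The regression step can be illustrated with the simplest of the regressors named above, a linear model. This is a hypothetical sketch rather than the authors' implementation: it uses a single validation domain, takes precomputed MDE features as given, omits a separate intercept (the mixture rates already sum to one, so they act as one), and replaces the paper's optimization objective with a plain search over a finite set of candidate mixtures. All helper names are illustrative.

```python
import numpy as np


def features(lams, mde):
    # Columns: mixture rates lambda_1..lambda_k followed by the MDE feature.
    # No intercept column: each row of lams sums to 1, so the rates already
    # span a constant term.
    return np.column_stack([np.asarray(lams, float), np.asarray(mde, float)])


def fit_linear(lams, mde, losses):
    """Least-squares fit of observed losses on [mixture rates, MDE feature]."""
    X = features(lams, mde)
    coef, *_ = np.linalg.lstsq(X, np.asarray(losses, float), rcond=None)
    return coef


def predict(coef, lams, mde):
    return features(lams, mde) @ coef


def best_mixture(coef, cand_lams, cand_mde):
    """Pick the candidate mixture with the lowest predicted loss."""
    preds = predict(coef, cand_lams, cand_mde)
    i = int(np.argmin(preds))
    return np.asarray(cand_lams)[i], float(preds[i])
```

With synthetic, noiseless data (for instance seven mixtures drawn with `np.random.default_rng(0).dirichlet(np.ones(3), size=7)` and losses generated as a linear function of the rates and the MDE feature), `fit_linear` recovers the generating coefficients. Real use would instead fit on measured losses of trained proxy models.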

Paper Structure

This paper contains 44 sections, 1 theorem, 10 equations, 8 figures, 11 tables, 1 algorithm.

Key Result

Proposition 3.1

For any $\lambda$ in the $(k-1)$-simplex, let $p^\star_\lambda = \arg\min_{p\in\mathcal{P}} L(p, \lambda)$ be the minimizer of the $\lambda$-weighted loss over all probability distributions. Then for any $(x,y)$:

$$p^\star_\lambda(y \mid x) = \sum_{i=1}^{k} \lambda_i'(x)\, p^\star_i(y \mid x),$$

where we use the shorthand $p^\star_i$ for the minimizer of $L(p, D_i)$, the expected loss on domain $i$. The coefficients $\lambda_i'$ satisfy $\lambda_i'(x) \propto D_i(x)\lambda_i$, normalized so that $\sum_{i=1}^{k} \lambda_i'(x) = 1$.
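Proposition 3.1 can be checked numerically on a toy problem. The sketch below assumes the loss is cross-entropy, that $\mathcal{P}$ contains all conditional distributions $p(y \mid x)$, and that each domain $D_i$ is a joint distribution over a small discrete set of contexts $x$ and tokens $y$; the random distributions and the helper names `weighted_ce` and `check_proposition` are illustrative only. It recombines the per-domain minimizers with the context-dependent weights $\lambda_i'(x) \propto D_i(x)\lambda_i$ and compares the result with the directly computed minimizer.

```python
import numpy as np


def weighted_ce(p, D, lam):
    """Lambda-weighted cross-entropy: sum_i lam_i * E_{(x,y)~D_i}[-log p(y|x)]."""
    return float(sum(l * np.sum(Di * -np.log(p)) for l, Di in zip(lam, D)))


def check_proposition(D, lam):
    """Return (direct minimizer, recombined per-domain minimizers, p*_i).

    D: shape (k, nx, ny); D[i] is the joint distribution of (x, y) in domain i.
    lam: shape (k,), mixture weights on the (k-1)-simplex.
    """
    lam = np.asarray(lam, dtype=float)
    # Direct minimizer of the lambda-weighted cross-entropy: the conditional
    # of y given x under the mixture distribution sum_i lam_i D_i.
    mix = np.tensordot(lam, D, axes=1)                 # (nx, ny)
    p_star_lam = mix / mix.sum(axis=1, keepdims=True)

    # Per-domain minimizers p*_i(y|x) = D_i(y|x).
    marg = D.sum(axis=2)                               # D_i(x), shape (k, nx)
    p_star = D / marg[:, :, None]                      # (k, nx, ny)

    # Context-dependent weights lambda'_i(x) ∝ D_i(x) * lambda_i, normalized over i.
    w = lam[:, None] * marg
    w = w / w.sum(axis=0, keepdims=True)
    recombined = np.einsum("ix,ixy->xy", w, p_star)
    return p_star_lam, recombined, p_star
```

Running it with, for example, `D = rng.random((3, 4, 5))` normalized per domain and `lam = np.array([0.2, 0.5, 0.3])` gives matching distributions. The weights depend on $x$: averaging the $p^\star_i$ with the fixed weights $\lambda_i$ is in general not the minimizer, and `weighted_ce` confirms that such a fixed-weight average can only match or exceed the loss of $p^\star_\lambda$.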

Figures (8)

  • Figure 1: Illustration of our approach. Data experts $E_i$ are trained on individual pre-training mixture domains $D_i$. The per-token $p_{\text{MDE}}$ approximations are generated as a $\lambda$-weighted average of the probabilities predicted by the individual experts. Then, for each validation domain, the MDE feature is computed as the average log-probability under $p_{\text{MDE}}$ across its tokens. Lastly, the mixture weights $\lambda$ and the MDE features are used to fit a regression model that maps $\lambda$ to predicted validation losses. The optimal set of weights is found by optimizing an objective function based on the regression model.
  • Figure 2: Mean squared error (MSE) and Spearman's rank correlation ($\rho$) for predicting the average loss over SlimPajama validation domains only (SP) and over all validation domains (ET+SP), using regressors from prior work as well as those proposed in this work. Regressors are fitted on 25 training mixtures (except MDE, which uses only 7 training mixtures) and evaluated on 48 held-out mixtures. MDE features bring large improvements across regressors.
  • Figure 3: Per-domain mean squared error of loss predictions for SlimPajama validation domains.
  • Figure 4: Pairwise ranking accuracy on 55 data mixtures (510M models trained for 50K steps), based on proxy models of different sizes and numbers of training steps.
  • Figure 5: Spearman's rank correlation on the SlimPajama (SP) validation domains as a function of the number of training mixtures.
  • ...and 3 more figures
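The evaluation metrics referenced in the figure captions (MSE, Spearman's $\rho$, and pairwise ranking accuracy) can be computed as follows. This is a sketch under stated assumptions: the Spearman implementation assumes no tied values, and pairwise ranking accuracy is taken to be the fraction of mixture pairs whose predicted losses are ordered the same way as their measured losses, skipping pairs with tied measured loss. The exact definitions used in the paper may differ, and the function names are illustrative.

```python
import itertools

import numpy as np


def mse(pred, true):
    pred, true = np.asarray(pred, float), np.asarray(true, float)
    return float(np.mean((pred - true) ** 2))


def spearman_rho(pred, true):
    # Assumes no ties: rank-transform both vectors, then take the Pearson
    # correlation of the ranks.
    rp = np.argsort(np.argsort(pred))
    rt = np.argsort(np.argsort(true))
    return float(np.corrcoef(rp, rt)[0, 1])


def pairwise_ranking_accuracy(pred, true):
    # Fraction of pairs (i, j) ordered the same way by predicted and measured
    # loss; pairs with equal measured loss are skipped.
    agree = total = 0
    for i, j in itertools.combinations(range(len(true)), 2):
        if true[i] == true[j]:
            continue
        total += 1
        agree += int((pred[i] - pred[j]) * (true[i] - true[j]) > 0)
    return agree / total
```

For measured losses `[2.0, 2.5, 3.0, 3.5]` and predictions `[2.1, 2.4, 3.2, 3.1]`, only the last pair is misordered, so pairwise accuracy is 5/6, Spearman's $\rho$ is 0.8, and the MSE is 0.055.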

Theorems & Definitions (2)

  • Proposition 3.1
  • Proof