MixMin: Finding Data Mixtures via Convex Minimization

Anvith Thudi, Evianne Rovers, Yangjun Ruan, Tristan Thrush, Chris J. Maddison

TL;DR

MixMin reframes data source mixing as a convex optimization problem that arises when model expressivity is high, enabling gradient-based discovery of optimal source weights using cheap proxy models. For CE and MSE losses with no covariate shift, the Bayes-optimal mixture reduces to a linear combination of per-source Bayes models, allowing a simple empirical objective over a target dataset. The method optimizes this objective via entropic descent on the simplex using proxies, then remixes data according to the learned weights. Across language modeling and chemistry tasks, MixMin yields consistent improvements with only about 1% of complete training compute spent on proxies, and the benefits transfer to larger models and larger pools of sources, highlighting a scalable data curation paradigm. Limitations include reliance on no covariate shift and on CE/MSE losses, suggesting avenues for extending the framework to broader settings.
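The optimization step described above can be illustrated with a minimal sketch. It assumes the unconditional CE case, where the empirical objective over a target dataset is the negative log of a weighted sum of per-source proxy probabilities, and it uses exponentiated-gradient updates as one standard form of entropic descent on the simplex. The function names, the `proxy_probs` array layout, the step size, and the iteration count are all hypothetical choices for illustration, not the authors' implementation.

```python
import numpy as np


def mixmin_loss(proxy_probs, weights):
    """Mean negative log-likelihood of the weighted proxy mixture on the target set.

    proxy_probs[j, i] is the probability that the proxy model for source i
    assigns to target example j (assumed layout, for illustration only).
    """
    return float(-np.log(proxy_probs @ weights).mean())


def mixmin_weights(proxy_probs, n_steps=200, step_size=1.0):
    """Exponentiated-gradient (entropic) descent over the probability simplex."""
    n_sources = proxy_probs.shape[1]
    # Start from the uniform mixture; keep weights in log space for stability.
    log_w = np.full(n_sources, -np.log(n_sources))
    for _ in range(n_steps):
        w = np.exp(log_w)
        mix = proxy_probs @ w  # mixture probability of each target example
        # Gradient of the mean negative log-likelihood with respect to w.
        grad = -(proxy_probs / mix[:, None]).mean(axis=0)
        # Multiplicative update, then renormalize so the weights sum to one.
        log_w = log_w - step_size * grad
        log_w -= np.logaddexp.reduce(log_w)
    return np.exp(log_w)
```

In this sketch, each column of `proxy_probs` would come from evaluating one cheap per-source proxy model on the target data; the returned weights would then be used to remix the sources before training the larger model.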

Abstract

Modern machine learning pipelines are increasingly combining and mixing data from diverse and disparate sources, e.g., pre-training large language models. Yet, finding the optimal data mixture is a challenging and open problem. We formalize this data mixing problem as a bi-level objective: the best mixture is the one that would lead to the best model for a downstream objective. Unfortunately, this objective is generally intractable. In this paper, we make the observation that the bi-level data mixing objective becomes convex as our model class becomes larger. We develop and study a gradient-based approach for optimizing this convex objective, which we call MixMin, and test it on language modeling and chemistry tasks. MixMin was the only method that uniformly improved the data mixture in all our experiments. With MixMin, we improved the data mixture using less than 0.2% additional compute for a pythia-410M model trained on 8.2B tokens, resulting in a 1-5% relative improvement in negative log likelihood on PIQA, ARC Easy, SciQ, and OpenWebMath. Crucially, we found that MixMin mixtures for smaller models improved training of larger models, suggesting that MixMin mixtures may be scale-invariant. When mixing bioassay data to train an XGBoost model, we saw improvements to average precision scores of 0.03-0.15.

Paper Structure

This paper contains 37 sections, 2 theorems, 10 equations, 18 figures, 5 tables, 1 algorithm.

Key Result

Theorem 3.1

Let the objective for eq:DM be unconditional CE, $\int -\log(f(x))\, dp(x)$, or conditional CE or MSE with no covariate shift, $\int -\log(f^{y}(x))\, dp(x,y)$ and $\int \|f(x) - y\|_2^2\, dp(x,y)$, with $p(x) = p'(x)~\forall p, p' \in P$. Suppose also $\mathcal{H}$ contains the Bayes optimal model for each m

Figures (18)

  • Figure 1: Optimizing MixMin to find the mixture weights requires training a few cheap models for each source and a target dataset. Given the mixture weights we train a more expensive model using the mixture.
  • Figure 2: The convex MixMin objective better approximates the eq:DM objective as the model class becomes larger (and better approximates Bayes optimal).
  • Figure 3: MixMin consistently outperforms all baselines across the four target tasks using $1\%$ of the final training run compute (Left). We report improvement over the downstream generative loss of training on the natural distribution (which is stated beside the task name): higher is better. Error bars indicate a $95\%$ confidence interval. Furthermore, we find that MixMin was robust to using less compute (Right), while RegMix and random search had their performance degrade with less compute.
  • Figure 4: MixMin mixtures derived from small models continue to improve training for larger models as measured by target accuracy (in most cases). We report accuracy with the error bars representing a $95\%$ confidence interval over $3$ trials. MixMin weights were found using $1\%$ of the compute of the $160M$ model training run, which is $\sim 0.15\%$ of the compute of the $410M$ training run.
  • Figure 5: MixMin mixtures derived from small models continue to improve training for larger models as measured by target generative loss (in all cases). We report generative loss with the error bars representing a $95\%$ confidence interval over $3$ trials (lower is better). MixMin weights were found using $1\%$ of the compute of the $160M$ model training run, which is $\sim 0.15\%$ of the compute of the $410M$ training run.
  • ...and 13 more figures

Theorems & Definitions (4)

  • Theorem 3.1
  • Lemma 3.2
  • proof
  • proof