Data Selection via Optimal Control for Language Models

Yuxian Gu, Li Dong, Hongning Wang, Yaru Hao, Qingxiu Dong, Furu Wei, Minlie Huang

TL;DR

This work treats pre-training data selection for language models as a discrete Optimal Control problem and derives Pontryagin's Maximum Principle (PMP) conditions to characterize optimal data weights. Building on this theory, the authors propose PMP-based Data Selection (PDS), a practical offline pipeline that uses a proxy dataset to solve PMP conditions, learns a data scorer to transfer quality estimates to a full corpus, and selects a high-quality pre-training subset. Empirical results show that PDS accelerates LM learning and improves downstream performance across model sizes (from 160M to 1.7B) and tasks, while also extending benefits to extrapolated 400B-parameter models; in data-constrained settings, PDS reduces data demand by about 1.8x. The framework is offline, scalable, and transfer-friendly, enabling integration into optimized pre-training pipelines and offering a theory-grounded mechanism to understand the impact of individual data points on LM pre-training.

Abstract

This work investigates the selection of high-quality pre-training data from massive corpora to enhance LMs' capabilities for downstream usage. We formulate data selection as a generalized Optimal Control problem, which can be solved theoretically by Pontryagin's Maximum Principle (PMP), yielding a set of necessary conditions that characterize the relationship between optimal data selection and LM training dynamics. Based on these theoretical results, we introduce PMP-based Data Selection (PDS), a framework that approximates optimal data selection by solving the PMP conditions. In our experiments, we adopt PDS to select data from CommonCrawl and show that the PDS-selected corpus accelerates the learning of LMs and consistently boosts their performance on a wide range of downstream tasks across various model sizes. Moreover, the benefits of PDS extend to ~400B models trained on ~10T tokens, as evidenced by the extrapolation of the test loss curves according to the Scaling Laws. PDS also improves data utilization when the pre-training data is limited, by reducing the data demand by 1.8 times, which helps mitigate the quick exhaustion of available web-crawled corpora. Our code, model, and data can be found at https://github.com/microsoft/LMOps/tree/main/data_selection.

Paper Structure

This paper contains 66 sections, 3 theorems, 48 equations, 13 figures, 9 tables, 2 algorithms.

Key Result

Theorem 2.1

Let ${\bm{\gamma}}^*$ solve the problem in Eq. (eq:obj), and ${\bm{\theta}}^*_t$ denote the LM parameters trained with ${\bm{\gamma}}^*$. For $0\le t < T$, there exists a vector ${\bm{\lambda}}^*_t \in {\mathbb{R}}^N$ such that the PMP conditions hold, where $\nabla^2 L({\bm{\theta}}^*_t, {\bm{\gamma}}^*)$ denotes the Hessian matrix of $L({\bm{\theta}}, {\bm{\gamma}}^*)$ with respect to ${\bm{\theta}}$ evaluated at ${\bm{\theta}} = {\bm{\theta}}^*_t$.
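
The conditions that Theorem 2.1 refers to plausibly take the following form. This is a sketch under assumptions rather than a quotation of the theorem: it assumes plain gradient descent with a learning rate $\eta$, a weighted training loss $L({\bm{\theta}}, {\bm{\gamma}}) = \sum_n \gamma_n\, l(x_n, {\bm{\theta}})$, a downstream loss $J$, an initial point ${\bm{\theta}}_0$, and a feasible set $U$ for the weights. The symbols $\eta$, $J$, ${\bm{\theta}}_0$, and $U$ do not appear in this excerpt and are introduced here for illustration only.

$$
\begin{aligned}
{\bm{\theta}}^*_{t+1} &= {\bm{\theta}}^*_t - \eta\, \nabla L({\bm{\theta}}^*_t, {\bm{\gamma}}^*), \qquad {\bm{\theta}}^*_0 = {\bm{\theta}}_0, \\
{\bm{\lambda}}^*_t &= {\bm{\lambda}}^*_{t+1} + \nabla J({\bm{\theta}}^*_t) - \eta\, \nabla^2 L({\bm{\theta}}^*_t, {\bm{\gamma}}^*)\, {\bm{\lambda}}^*_{t+1}, \qquad {\bm{\lambda}}^*_T = \nabla J({\bm{\theta}}^*_T), \\
{\bm{\gamma}}^* &= \arg\max_{{\bm{\gamma}} \in U} \sum_{n} \gamma_n \Big[ \sum_{t=0}^{T-1} {{\bm{\lambda}}^*_{t+1}}^{\top}\, \nabla l(x_n, {\bm{\theta}}^*_t) \Big].
\end{aligned}
$$

Read this way, the first line is the training trajectory, the second is a backward recursion for the "target vector" ${\bm{\lambda}}^*_{t+1}$ described in Figure 2, and the third scores each instance by the dot product between that vector and the instance gradient.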

Figures (13)

  • Figure 1: Scaling curves of average accuracy on 9 widely-used downstream tasks with respect to computation (a) and model sizes (b). We select pre-training corpora from CommonCrawl and pre-train LMs on the selected data. $\textsc{PDS}$ is compared with the RedPajama data cleaning pipeline [redpajama]. The computation curves are calculated based on the training of a 1.7B LM.
  • Figure 2: An illustration of Theorem 2.1. Left: ${\bm{\lambda}}^*_{t+1} \in \mathbb{R}^N$ defines a "target vector" aligning with the optimization direction towards optimal data selection, as in Eq. (eq:pmp_lam). Right: data quality scores are positively correlated with how close the gradient direction of each instance is to the target direction, calculated as the dot-product between ${\bm{\lambda}}^*_{t+1}$ and $\nabla l_{i,t} = \nabla l(x_i,{\bm{\theta}}^*_t)$ for $i=n,m,k$, as in Eq. (eq:pmp_max).
  • Figure 3: The $\textsc{PDS}$ framework. We compute data quality scores ${\bm{\gamma}}^*$ on a proxy dataset $\mathcal{D}^\mathrm{prx}$ using Algorithm (alg:method), which is derived from Pontryagin's Maximum Principle [maximum_principle] (Section (sec:pmp_solver)). After that, the data scorer learns to predict quality scores from instances and then infers scores for a large corpus $\mathcal{D}$ (Section (sec:classifier)). Finally, a high-quality corpus $\mathcal{D}'$ is selected based on the inferred scores to pre-train an LM (Section (sec:data_selection)).
  • Figure 4: Test losses on the DCLM corpus [dclm] for 160M, 470M, 1B and 1.7B LMs.
  • Figure 5: Test losses on the DCLM corpus [dclm] in the data-constrained setting. We select data with $\textsc{PDS}$ for different selection ratios $r$ and train the model for multiple epochs to reach the same token budget.
  • ...and 8 more figures
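
The scoring rule illustrated in Figure 2 and the score-then-select flow of Figure 3 can be sketched on a toy problem. The example below is a minimal illustration under stated assumptions, not the authors' algorithm: it swaps the LM for linear regression with a squared loss, runs a single forward/backward sweep with uniform data weights, skips the data scorer entirely, and selects by plain top-k. The names `pmp_scores` and `select_top`, the learning rate, the step count, and the synthetic label-flipped data are all hypothetical.

```python
import numpy as np


def pmp_scores(A, b, A_dev, b_dev, gamma, eta=0.1, T=20):
    """Toy PMP-style sweep for linear regression (assumed setup, not an LM).

    Per-example loss: l(x_n, theta) = 0.5 * (A[n] @ theta - b[n]) ** 2.
    Downstream loss:  J(theta) = mean squared error on (A_dev, b_dev) / 2.
    Returns one score per training example:
        sum_t lambda_{t+1} . grad l(x_n, theta_t)
    """
    N, d = A.shape

    def grad_J(theta):
        return A_dev.T @ (A_dev @ theta - b_dev) / len(b_dev)

    # Forward pass: weighted gradient descent, keeping every theta_t.
    thetas = [np.zeros(d)]
    for _ in range(T):
        theta = thetas[-1]
        per_example = (A @ theta - b)[:, None] * A  # row n = grad l(x_n, theta)
        thetas.append(theta - eta * (gamma @ per_example))

    # Hessian of the weighted loss; constant in theta for squared loss.
    H = A.T @ (gamma[:, None] * A)

    # Backward pass: start from lambda_T and walk down to t = 0.
    lam = grad_J(thetas[T])
    scores = np.zeros(N)
    for t in range(T - 1, -1, -1):
        per_example = (A @ thetas[t] - b)[:, None] * A
        scores += per_example @ lam                  # lam holds lambda_{t+1}
        lam = lam + grad_J(thetas[t]) - eta * (H @ lam)  # now lambda_t
    return scores


def select_top(scores, ratio):
    """Keep the indices of the highest-scoring fraction (plain top-k)."""
    k = int(round(ratio * len(scores)))
    return np.argsort(-scores)[:k]


# Synthetic data (hypothetical): 140 clean examples, 60 with flipped labels.
rng = np.random.default_rng(0)
d, N = 5, 200
theta_true = rng.normal(size=d)
A = rng.normal(size=(N, d))
b = A @ theta_true
noisy = np.arange(N) >= 140
b[noisy] *= -1.0
A_dev = rng.normal(size=(100, d))
b_dev = A_dev @ theta_true

gamma = np.full(N, 1.0 / N)  # uniform starting weights
scores = pmp_scores(A, b, A_dev, b_dev, gamma)
keep = select_top(scores, ratio=0.5)
print("mean score, clean:", scores[~noisy].mean())
print("mean score, noisy:", scores[noisy].mean())
print("clean fraction in selection:", (~noisy)[keep].mean())
```

In this toy setting the label-flipped examples pull the parameters away from the direction that lowers the downstream loss, so their gradients point against the target vector and they receive lower scores than the clean examples. A real pipeline would replace the explicit Hessian with Hessian-vector products and would fit a scorer on these values to score the full corpus, as the Figure 3 caption describes.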

Theorems & Definitions (3)

  • Theorem 2.1: PMP Conditions for Data Selection
  • Theorem B.1: PMP
  • Theorem C.1: PMP Data Selection for Adam