Data Selection via Optimal Control for Language Models
Yuxian Gu, Li Dong, Hongning Wang, Yaru Hao, Qingxiu Dong, Furu Wei, Minlie Huang
TL;DR
This work treats pre-training data selection for language models as a discrete Optimal Control problem and derives Pontryagin's Maximum Principle (PMP) conditions that characterize optimal data weights. Building on this theory, the authors propose PMP-based Data Selection (PDS), a practical offline pipeline that solves the PMP conditions on a proxy dataset, trains a data scorer to transfer the resulting quality estimates to the full corpus, and selects a high-quality pre-training subset. Empirically, PDS accelerates LM learning and improves downstream performance across tasks and model sizes (160M to 1.7B parameters); extrapolating the test loss curves with the Scaling Laws suggests the benefits persist for ~400B-parameter models trained on ~10T tokens. In data-constrained settings, PDS reduces data demand by about 1.8x. The framework is offline, scalable, and transfer-friendly, so it can be integrated into optimized pre-training pipelines, and it offers a theory-grounded way to understand the impact of individual data points on LM pre-training.
Abstract
This work investigates the selection of high-quality pre-training data from massive corpora to enhance LMs' capabilities for downstream usage. We formulate data selection as a generalized Optimal Control problem, which can be solved theoretically by Pontryagin's Maximum Principle (PMP), yielding a set of necessary conditions that characterize the relationship between optimal data selection and LM training dynamics. Based on these theoretical results, we introduce PMP-based Data Selection (PDS), a framework that approximates optimal data selection by solving the PMP conditions. In our experiments, we adopt PDS to select data from CommonCrawl and show that the PDS-selected corpus accelerates the learning of LMs and consistently boosts their performance on a wide range of downstream tasks across various model sizes. Moreover, the benefits of PDS extend to ~400B models trained on ~10T tokens, as evidenced by extrapolating the test loss curves according to the Scaling Laws. PDS also improves data utilization when pre-training data is limited, reducing data demand by a factor of 1.8, which helps mitigate the quick exhaustion of available web-crawled corpora. Our code, model, and data can be found at https://github.com/microsoft/LMOps/tree/main/data_selection.
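To make the offline pipeline more concrete, the following is a minimal toy sketch of its last two stages: fitting a scorer on proxy examples and using it to select a subset of a larger corpus. It is not the authors' implementation. The function names, the random feature vectors, the 40% keep ratio, and the ridge-regression scorer are all assumptions made for illustration, and the proxy quality scores are synthetic stand-ins for the values that solving the PMP conditions would produce.

```python
import numpy as np


def fit_linear_scorer(proxy_features, proxy_scores, ridge=1e-3):
    """Fit a ridge-regression scorer (hypothetical stand-in for a learned data scorer)."""
    # Append a constant column so the scorer can learn a bias term.
    X = np.hstack([proxy_features, np.ones((len(proxy_features), 1))])
    A = X.T @ X + ridge * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ proxy_scores)


def score_corpus(weights, features):
    """Predict a quality score for every example in the full corpus."""
    X = np.hstack([features, np.ones((len(features), 1))])
    return X @ weights


def select_top_fraction(scores, keep_ratio):
    """Return indices of the highest-scoring examples, keeping `keep_ratio` of them."""
    k = max(1, int(round(keep_ratio * len(scores))))
    return np.argsort(-scores)[:k]


# Toy corpus: 1000 examples, each described by an 8-dimensional feature vector.
rng = np.random.default_rng(0)
corpus_features = rng.normal(size=(1000, 8))

# Assumption: quality scores for a 100-example proxy set are already available.
# Here they are synthetic; in PDS they would come from solving the PMP conditions.
true_w = rng.normal(size=8)
proxy_idx = rng.choice(1000, size=100, replace=False)
proxy_scores = corpus_features[proxy_idx] @ true_w

# Fit the scorer on the proxy set, score the full corpus, keep the top 40%.
weights = fit_linear_scorer(corpus_features[proxy_idx], proxy_scores)
scores = score_corpus(weights, corpus_features)
selected = select_top_fraction(scores, keep_ratio=0.4)
```

The point of the design is that the expensive step runs only on the small proxy set, while the cheap scorer is the only component applied to the full corpus, which is what keeps the selection offline and scalable.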
