Optimizing ML Training with Metagradient Descent
Logan Engstrom, Andrew Ilyas, Benjamin Chen, Axel Feldmann, William Moses, Aleksander Madry
TL;DR
The paper tackles the problem of optimizing training configurations for large-scale models by treating metaparameters as continuous variables and using gradients to search the design space. It introduces Replay, a scalable algorithm for computing exact metagradients (gradients of a post-training objective with respect to metaparameters) through iterative training, and a metasmoothness framework that makes these gradients informative enough to optimize with. Applying metagradient descent (MGD) to data selection, data poisoning, and learning rate schedule discovery yields state-of-the-art results in multimodal data curation (DataComp CLIP), instruction-tuning data selection, and accuracy-degrading data poisoning, and recovers competitive learning rate schedules. The work shows that, with metasmooth training and efficient metagradient computation, gradient-based optimization of training configurations scales to models with billions of parameters and to real-world datasets, with practical implications for both model performance and robustness to poisoning.
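To make the idea of a metagradient concrete, the Python sketch below differentiates a validation loss through an unrolled SGD run with respect to the per-step learning rates (a learning rate schedule), then takes plain gradient steps on that schedule. Everything here is an illustrative assumption: the 1-D quadratic losses, the constants A, B, C, and the helper names sgd_step, train, val_loss, metagradient, and mgd are hypothetical. The backward pass uses simple fixed-interval checkpointing so that not every iterate has to be stored; it is a stand-in that shows the memory/recomputation trade-off, not the paper's Replay algorithm.

```python
from typing import Dict, List, Tuple

# Hypothetical toy problem (not from the paper): a 1-D quadratic training loss
#   f(w) = 0.5 * A * (w - B)**2   and validation loss   L(w) = 0.5 * (w - C)**2.
# The metaparameters are the per-step learning rates lrs[0..T-1].
A, B, C = 2.0, 1.0, 0.8


def sgd_step(w: float, lr: float) -> float:
    """One gradient step on f: w - lr * f'(w)."""
    return w - lr * A * (w - B)


def train(w0: float, lrs: List[float]) -> float:
    """Run one SGD step per entry of lrs and return the final iterate w_T."""
    w = w0
    for lr in lrs:
        w = sgd_step(w, lr)
    return w


def val_loss(w: float) -> float:
    return 0.5 * (w - C) ** 2


def metagradient(w0: float, lrs: List[float], k: int = 4) -> Tuple[float, List[float]]:
    """Return L(w_T) and the exact dL(w_T)/dlrs[t] for every step t.

    The forward pass keeps only every k-th iterate. The backward pass visits
    k-step segments from last to first, re-runs each one from its checkpoint
    to recover w_t, then applies the chain rule in reverse. This fixed-interval
    checkpointing is a simple stand-in for Replay, not a reimplementation of it.
    """
    T = len(lrs)
    checkpoints: Dict[int, float] = {}
    w = w0
    for t in range(T):
        if t % k == 0:
            checkpoints[t] = w
        w = sgd_step(w, lrs[t])
    loss = val_loss(w)

    g = w - C  # dL/dw_T
    grads = [0.0] * T
    for start in reversed(range(0, T, k)):
        end = min(start + k, T)
        states = [checkpoints[start]]  # recompute w_start .. w_{end-1}
        for t in range(start, end - 1):
            states.append(sgd_step(states[-1], lrs[t]))
        for t in reversed(range(start, end)):
            w_t = states[t - start]
            grads[t] = g * (-A * (w_t - B))  # dw_{t+1}/dlr_t = -f'(w_t)
            g *= 1.0 - lrs[t] * A            # dw_{t+1}/dw_t = 1 - lr_t * A
    return loss, grads


def mgd(w0: float, lrs: List[float], meta_lr: float = 0.1, steps: int = 50) -> List[float]:
    """Metagradient descent: plain gradient steps on the learning-rate schedule."""
    lrs = list(lrs)
    for _ in range(steps):
        _, grads = metagradient(w0, lrs)
        lrs = [lr - meta_lr * g for lr, g in zip(lrs, grads)]
    return lrs


if __name__ == "__main__":
    initial = [0.05] * 10
    tuned = mgd(0.0, initial)
    print(val_loss(train(0.0, initial)), val_loss(train(0.0, tuned)))
```

With a checkpoint every k steps, the backward pass holds roughly T/k + k iterates and re-runs each training step about once more, so k near sqrt(T) balances memory against recomputation. The choice of k changes only the memory profile, not the metagradient.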
Abstract
A major challenge in training large-scale machine learning models is configuring the training process to maximize model performance, i.e., finding the best training setup from a vast design space. In this work, we unlock a gradient-based approach to this problem. We first introduce an algorithm for efficiently calculating metagradients (gradients through model training) at scale. We then introduce a "smooth model training" framework that enables effective optimization using metagradients. With metagradient descent (MGD), we greatly improve on existing dataset selection methods, outperform accuracy-degrading data poisoning attacks by an order of magnitude, and automatically find competitive learning rate schedules.
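Dataset selection fits the same mold once each training example carries a continuous weight: the metagradient of validation loss with respect to an example's weight indicates whether upweighting that example helps or hurts. The Python sketch below is a hypothetical illustration, not the paper's selection method. It assumes a 1-D linear model trained by full-batch gradient descent, invented toy data with one deliberately corrupted label, and hypothetical helper names (train_iterates, validation_loss, data_metagradient). It stores every iterate for brevity, which is exactly what becomes infeasible at scale.

```python
from typing import List, Tuple

# Hypothetical toy data (not from the paper): targets follow y = 2x,
# except the last training example, whose label is deliberately corrupted.
TRAIN_X = [1.0, 2.0, 3.0, 1.5]
TRAIN_Y = [2.0, 4.0, 6.0, -3.0]
VAL_X = [1.0, 2.5]
VAL_Y = [2.0, 5.0]
LR, STEPS = 0.05, 50


def train_grad(theta: float, weights: List[float]) -> float:
    """Gradient of (1/n) * sum_i c_i * 0.5 * (theta * x_i - y_i)**2 w.r.t. theta."""
    n = len(TRAIN_X)
    return sum(c * x * (theta * x - y) for c, x, y in zip(weights, TRAIN_X, TRAIN_Y)) / n


def validation_loss(theta: float) -> float:
    return 0.5 * sum((theta * x - y) ** 2 for x, y in zip(VAL_X, VAL_Y)) / len(VAL_X)


def train_iterates(weights: List[float]) -> List[float]:
    """Full-batch gradient descent from theta = 0; returns theta_0 .. theta_T."""
    thetas = [0.0]
    for _ in range(STEPS):
        thetas.append(thetas[-1] - LR * train_grad(thetas[-1], weights))
    return thetas


def data_metagradient(weights: List[float]) -> Tuple[float, List[float]]:
    """Exact dL(theta_T)/dc_i for every example weight c_i, via reverse-mode AD."""
    n = len(TRAIN_X)
    thetas = train_iterates(weights)
    theta_T = thetas[-1]
    g = sum((theta_T * x - y) * x for x, y in zip(VAL_X, VAL_Y)) / len(VAL_X)  # dL/dtheta_T
    curvature = sum(c * x * x for c, x in zip(weights, TRAIN_X)) / n  # d train_grad / d theta
    grads = [0.0] * n
    for t in reversed(range(STEPS)):
        theta_t = thetas[t]
        for i, (x, y) in enumerate(zip(TRAIN_X, TRAIN_Y)):
            # d theta_{t+1} / d c_i = -LR * x_i * (theta_t * x_i - y_i) / n
            grads[i] += g * (-LR * x * (theta_t * x - y) / n)
        g *= 1.0 - LR * curvature  # d theta_{t+1} / d theta_t
    return validation_loss(theta_T), grads


if __name__ == "__main__":
    loss, grads = data_metagradient([1.0] * len(TRAIN_X))
    # A positive metagradient means upweighting that example raises validation loss.
    keep = [0.0 if g > 0 else 1.0 for g in grads]
    print(grads, loss, validation_loss(train_iterates(keep)[-1]))
```

On this toy data the corrupted example receives the only positive metagradient, and dropping it lowers validation loss. A fuller loop would take repeated small gradient steps on the weights, in the spirit of MGD, rather than a single thresholding pass.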
