Optimizing ML Training with Metagradient Descent

Logan Engstrom, Andrew Ilyas, Benjamin Chen, Axel Feldmann, William Moses, Aleksander Madry

TL;DR

The paper tackles the problem of optimizing training configurations for large-scale models by treating metaparameters as continuous variables and using gradients to search the design space. It introduces Replay, a scalable method to compute exact metagradients through iterative training, and a metasmoothness framework to make these gradients informative for optimization. Applying metagradient descent (MGD) to data selection, data poisoning, and learning rate schedule discovery yields state-of-the-art results in multimodal data curation (DataComp CLIP), instruction-tuning data selection, and accuracy-degrading data poisoning, while also recovering competitive learning rate schedules. The work demonstrates that, with metasmooth training and efficient gradient computation, gradient-based optimization of training configurations can scale to billions of parameters and real-world datasets, yielding practical improvements in model performance and robustness.
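To make "metaparameters as continuous variables" concrete, here is a minimal toy sketch, not the paper's implementation. It assumes the metaparameter z is a vector of per-example data weights, the learning algorithm is a few steps of full-batch gradient descent on a 1-D linear model, and the output phi is the squared error on a single validation point. The metagradient is computed exactly by reverse-mode differentiation through the unrolled training loop. All function names, the model, and the step sizes are hypothetical. Unlike Replay, this sketch simply stores every iterate, which is only affordable at toy scale.

```python
# Toy sketch (hypothetical names): exact metagradient of a validation loss
# with respect to per-example data weights z, via reverse-mode
# differentiation through T steps of weighted full-batch gradient descent.

def train(z, xs, ys, lr=0.1, steps=20, w0=0.0):
    """Run weighted gradient descent on a 1-D linear model; return w_0..w_T."""
    n = len(xs)
    ws = [w0]
    w = w0
    for _ in range(steps):
        # Gradient of (1/n) * sum_i z_i * (w * x_i - y_i)^2 with respect to w.
        g = (2.0 / n) * sum(zi * xi * (w * xi - yi) for zi, xi, yi in zip(z, xs, ys))
        w = w - lr * g
        ws.append(w)
    return ws


def val_loss(w, xv, yv):
    """Model output phi: squared error on one validation point."""
    return (w * xv - yv) ** 2


def metagradient(z, xs, ys, xv, yv, lr=0.1, steps=20):
    """Return d phi / d z by backpropagating through the unrolled loop."""
    n = len(xs)
    ws = train(z, xs, ys, lr, steps)      # toy version: keep every state
    adj = 2.0 * xv * (ws[-1] * xv - yv)   # adjoint d phi / d w_T
    grad_z = [0.0] * n
    # d g / d w does not depend on w for this quadratic loss.
    curv = (2.0 / n) * sum(zi * xi * xi for zi, xi in zip(z, xs))
    for t in reversed(range(steps)):
        w_t = ws[t]
        # Direct dependence of w_{t+1} = w_t - lr * g(w_t, z) on each z_i.
        for i in range(n):
            grad_z[i] += adj * (-lr * (2.0 / n) * xs[i] * (w_t * xs[i] - ys[i]))
        # Propagate the adjoint from w_{t+1} back to w_t.
        adj = adj * (1.0 - lr * curv)
    return grad_z
```

A positive entry means that upweighting that training example would increase the validation loss, so a data-selection step of metagradient descent would push its weight down. For example, with `xs = [1.0, 2.0, 3.0]`, `ys = [2.0, 4.0, 0.0]` (the third label is an outlier) and validation point `(1.5, 3.0)`, the metagradient entry for the outlier is positive.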

Abstract

A major challenge in training large-scale machine learning models is configuring the training process to maximize model performance, i.e., finding the best training setup from a vast design space. In this work, we unlock a gradient-based approach to this problem. We first introduce an algorithm for efficiently calculating metagradients -- gradients through model training -- at scale. We then introduce a "smooth model training" framework that enables effective optimization using metagradients. With metagradient descent (MGD), we greatly improve on existing dataset selection methods, outperform accuracy-degrading data poisoning attacks by an order of magnitude, and automatically find competitive learning rate schedules.
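The MGD loop itself can be illustrated on a scalar metaparameter. The sketch below assumes a toy quadratic training loss and uses hypothetical names throughout. It treats the learning rate as the metaparameter, differentiates the validation loss through the unrolled training run with forward-mode dual numbers, and takes plain gradient steps on the learning rate. Forward mode costs one extra pass per metaparameter, which is cheap for a single scalar but impractical for per-example data weights; that high-dimensional regime is the one the paper's Replay algorithm targets.

```python
# Toy sketch (hypothetical names): metagradient descent on a learning rate,
# using forward-mode automatic differentiation via dual numbers.

class Dual:
    """Number val + der * eps with eps^2 = 0; der tracks d(value)/d(metaparameter)."""

    def __init__(self, val, der=0.0):
        self.val, self.der = float(val), float(der)

    def _lift(self, o):
        return o if isinstance(o, Dual) else Dual(o)

    def __add__(self, o):
        o = self._lift(o)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __sub__(self, o):
        o = self._lift(o)
        return Dual(self.val - o.val, self.der - o.der)

    def __rsub__(self, o):
        return self._lift(o) - self

    def __mul__(self, o):
        o = self._lift(o)
        return Dual(self.val * o.val, self.val * o.der + self.der * o.val)

    __rmul__ = __mul__


def train_and_eval(lr, steps=5, target=3.0):
    """Unrolled GD on the training loss (w - target)^2; return validation loss phi."""
    w = Dual(0.0)
    for _ in range(steps):
        grad = 2 * (w - target)  # d/dw of (w - target)^2
        w = w - lr * grad
    err = w - target
    return err * err


def mgd(lr0=0.1, meta_lr=0.005, meta_steps=50):
    """Metagradient descent on the learning rate."""
    lr = lr0
    for _ in range(meta_steps):
        phi = train_and_eval(Dual(lr, 1.0))  # seed d lr / d lr = 1
        lr -= meta_lr * phi.der              # step along the negative metagradient
    return lr
```

Starting from a learning rate of 0.1, `mgd()` increases the learning rate and lowers the validation loss after the fixed 5-step training budget; with this toy loss the step size 0.5 would converge in a single step, so MGD moves toward it.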

Paper Structure

This paper contains 78 sections, 29 equations, 14 figures, 6 tables, 3 algorithms.

Figures (14)

  • Figure 1: Our proto-algorithm, metagradient descent (MGD), uses gradients to achieve state-of-the-art performance across a variety of applications, including data selection and data poisoning.
  • Figure 2: An illustration of the metagradient. We embed a given aspect of the training setup (e.g., the training dataset, or optimizer hyperparameters) into a continuous metaparameter vector $z \in \mathbb{R}^d$. This metaparameter defines a model $\mathcal{A}(z)$ by way of the learning algorithm $\mathcal{A}$, which in turn defines an output $\phi(z)$. The metagradient $\nabla_z \phi(\mathcal{A}(z)) \in \mathbb{R}^d$ is the gradient of this model output with respect to the metaparameter.
  • Figure 3: The lazy $k$-ary tree structure for traversing optimizer states in reverse order, with $k=2$. Recall that $n$ is the number of states (parameterized such that $n=T+1$). Each node represents the correspondingly numbered state. The blue arrows give an example traversal, showing the path up to state $s_{\frac{3n}{4}+1}$. The gray cylinders indicate the states that are stored when the traversal is at state $s_{\frac{3n}{4}+1}$; the other states are not stored at this point in the traversal. Traversing this structure requires storing $\mathcal{O}(\log n)$ states and computing $\mathcal{O}(n \log n)$ optimizer steps, compared with $n$ steps for ordinary training.
  • Figure 4: (a) For a variety of training configurations of a ResNet-9 model, we plot metasmoothness (Definition 2) against test accuracy. Strategies such as increasing width, placing batch normalization before activations, and scaling down network outputs consistently improve metasmoothness at a minor cost to accuracy. (b) Smoother training configurations can be optimized more effectively via metagradients. Here, as in the data poisoning section, we use metagradients to perform gradient ascent on validation loss.
  • Figure 5: The effect of metasmoothness on the optimization landscape. Each plot visualizes the loss landscape of a (deterministic) learning algorithm $\mathcal{A}$, with the $x$- and $y$-axes representing additive perturbations to 1000 examples in the training set and the $z$-axis representing the resulting model's loss on the test example given in the title. In each row, the left plot is a non-smooth algorithm and the right plot is a smooth algorithm (as per Definition 2), evaluated on the same example. Overall, empirical metasmoothness appears to correlate strongly with qualitative landscape smoothness. See the paper's additional landscape figures for more examples.
  • ...and 9 more figures
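The reverse-order traversal in Figure 3 can be sketched with a recursive checkpointing scheme: rather than storing all $n$ states, keep only the first state and recompute later ones on demand by splitting each index range into at most $k$ chunks. This is a minimal sketch under stated assumptions, not the authors' Replay code; the function names are hypothetical, and its exact checkpoint placement may differ from the paper's lazy tree. It keeps about $k$ checkpoints per recursion level, $\mathcal{O}(k \log_k n)$ in total, and performs $\mathcal{O}(n \log_k n)$ forward steps.

```python
# Sketch (hypothetical names): yield optimizer states in reverse order using
# recursive k-way checkpointing instead of storing all n states.

def reverse_states(s0, step, n, k=2):
    """Yield s_{n-1}, s_{n-2}, ..., s_0, where s_{t+1} = step(s_t).

    Only s0 is given; other states are recomputed from checkpoints, and at most
    about k checkpoints are held per recursion level.
    """

    def visit(s_lo, lo, hi):
        # Invariant: s_lo is the state at index lo; yield s_{hi-1} down to s_lo.
        if hi - lo == 1:
            yield s_lo
            return
        length = hi - lo
        # Split [lo, hi) into at most k contiguous, non-empty chunks.
        cuts = sorted({lo + (length * j) // k for j in range(k)})
        bounds = list(zip(cuts, cuts[1:] + [hi]))
        # Walk forward once, checkpointing the first state of every chunk.
        checkpoints, s, t = [], s_lo, lo
        for start, _ in bounds:
            while t < start:
                s, t = step(s), t + 1
            checkpoints.append(s)
        # Visit chunks last-to-first, releasing each checkpoint once used.
        for start, end in reversed(bounds):
            yield from visit(checkpoints.pop(), start, end)

    if n > 0:
        yield from visit(s0, 0, n)
```

In a metagradient computation, each yielded state $s_t$ would feed one backward step of the adjoint recursion. With `step = lambda s: s + 1`, `list(reverse_states(0, step, 5))` returns `[4, 3, 2, 1, 0]`. Larger $k$ trades more stored checkpoints for fewer recomputed steps.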

Theorems & Definitions (6)

  • Remark 1: Connection to rematerialization
  • Remark 2: Reversible learning
  • Definition 1: Metasmoothness of $f$ at $\mathbf{z}$ towards $\mathbf{v}$
  • Definition 2: Empirical metasmoothness of $\mathcal{A}$
  • Remark 3
  • Remark 4: Poisoning non-smooth learning algorithms