
The AdEMAMix Optimizer: Better, Faster, Older

Matteo Pagliardini, Pierre Ablin, David Grangier

TL;DR

This work questions the use of a single EMA to accumulate past gradients and empirically demonstrates how this choice can be sub-optimal: a single EMA cannot simultaneously give a high weight to the immediate past and a non-negligible weight to older gradients.
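The limitation can be checked numerically with the closed form given in the Figure 3 caption, where a gradient that is $i$ steps old receives weight $\beta^i(1-\beta)$. The Python sketch below evaluates that weight for a few $\beta$ values and confirms that the usual recursive EMA update, started from zero, matches the closed form. The function names are our own, chosen for illustration.

```python
def ema_weight(beta: float, age: int) -> float:
    """Weight a single EMA assigns to the gradient that is `age` steps old."""
    return beta**age * (1.0 - beta)


def ema_recursive(beta: float, grads: list[float]) -> float:
    """Standard recursive EMA, m <- beta * m + (1 - beta) * g, started at m = 0."""
    m = 0.0
    for g in grads:
        m = beta * m + (1.0 - beta) * g
    return m


def ema_closed_form(beta: float, grads: list[float]) -> float:
    """sum_{i=0}^{T} beta^i (1 - beta) g^{(T-i)}, with grads[-1] the newest gradient."""
    T = len(grads) - 1
    return sum(ema_weight(beta, i) * grads[T - i] for i in range(T + 1))


if __name__ == "__main__":
    for beta in (0.9, 0.999, 0.9999):
        newest = ema_weight(beta, 0)
        old = ema_weight(beta, 10_000)
        print(f"beta={beta}: newest={newest:.3g}, 10k steps old={old:.3g}")
```

With $\beta=0.9$ the newest gradient receives weight $0.1$, while a $10k$-step-old gradient receives a weight that underflows to zero in double precision. With $\beta=0.9999$ the old gradient keeps a weight of roughly $3.7\times10^{-5}$, but the newest one gets only $10^{-4}$.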

Abstract

Momentum-based optimizers are central to a wide range of machine learning applications. These typically rely on an Exponential Moving Average (EMA) of gradients, which exponentially decays the contribution of older gradients. This accounts for gradients being local linear approximations that lose their relevance as the iterate moves along the loss landscape. This work questions the use of a single EMA to accumulate past gradients and empirically demonstrates how this choice can be sub-optimal: a single EMA cannot simultaneously give a high weight to the immediate past and a non-negligible weight to older gradients. Building on this observation, we propose AdEMAMix, a simple modification of the Adam optimizer that uses a mixture of two EMAs to better take advantage of past gradients. Our experiments on language modeling and image classification show, quite surprisingly, that gradients can stay relevant for tens of thousands of steps. They help to converge faster, and often to lower minima: e.g., a $1.3$B parameter AdEMAMix LLM trained on $101$B tokens performs comparably to an AdamW model trained on $197$B tokens ($+95\%$). Moreover, our method significantly slows down model forgetting during training. Our work motivates further exploration of different types of functions, beyond EMAs, to leverage past gradients.
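The abstract describes AdEMAMix only as Adam with a mixture of two EMAs. The sketch below is our reading of one update step on a flat list of parameters: a fast, bias-corrected EMA `m1` (rate $\beta_1$), a slow EMA `m2` (rate $\beta_3$) scaled by $\alpha$, and Adam's second-moment EMA `nu` (rate $\beta_2$). The paper's Algorithm 1 is not reproduced in this excerpt, so details such as where bias correction is applied are assumptions; the names `AdEMAMixState`, `init_state`, and `ademamix_step` are hypothetical, and the $\alpha$/$\beta_3$ schedulers mentioned in the figure captions are omitted.

```python
import math
from dataclasses import dataclass


@dataclass
class AdEMAMixState:
    """Optimizer state for a flat list of parameters; all buffers start at zero."""
    m1: list[float]  # fast EMA of gradients (rate beta1)
    m2: list[float]  # slow EMA of gradients (rate beta3)
    nu: list[float]  # EMA of squared gradients (rate beta2)
    t: int = 0       # number of steps taken so far


def init_state(n: int) -> AdEMAMixState:
    return AdEMAMixState([0.0] * n, [0.0] * n, [0.0] * n)


def ademamix_step(theta: list[float], grad: list[float], state: AdEMAMixState,
                  lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
                  beta3: float = 0.9999, alpha: float = 5.0, eps: float = 1e-8,
                  weight_decay: float = 0.0) -> list[float]:
    """Return updated parameters after one AdEMAMix-style step (mutates `state`).

    Sketch only: the alpha and beta3 schedulers used in the paper are omitted.
    """
    state.t += 1
    bc1 = 1.0 - beta1 ** state.t  # bias correction for the fast EMA
    bc2 = 1.0 - beta2 ** state.t  # bias correction for the second moment
    new_theta = []
    for i, (p, g) in enumerate(zip(theta, grad)):
        state.m1[i] = beta1 * state.m1[i] + (1.0 - beta1) * g
        state.m2[i] = beta3 * state.m2[i] + (1.0 - beta3) * g  # not bias-corrected here
        state.nu[i] = beta2 * state.nu[i] + (1.0 - beta2) * g * g
        m1_hat = state.m1[i] / bc1
        nu_hat = state.nu[i] / bc2
        direction = (m1_hat + alpha * state.m2[i]) / (math.sqrt(nu_hat) + eps)
        # Decoupled weight decay, applied as in AdamW.
        new_theta.append(p - lr * (direction + weight_decay * p))
    return new_theta


if __name__ == "__main__":
    # Hypothetical toy problem: minimize sum(x_i^2), whose gradient is 2 * x.
    theta = [1.0, -2.0]
    state = init_state(len(theta))
    for _ in range(100):
        theta = ademamix_step(theta, [2.0 * x for x in theta], state, lr=0.05)
    print(theta)
```

Setting `alpha=0.0` removes the slow EMA and leaves an Adam update with decoupled (AdamW-style) weight decay, which makes a quick sanity check. Because `m2` starts at zero and is not bias-corrected in this sketch, its contribution is small early in training.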

Paper Structure (33 sections, 18 equations, 33 figures, 8 tables, 1 algorithm)

Figures (33)

  • Figure 1: Comparing AdamW and AdEMAMix on language modeling. In (a,b,c), we plot the loss obtained using AdamW and AdEMAMix (our optimizer) to train Transformer models of various sizes on the RedPajama dataset. In (a), we train multiple baselines for $256k$, $400k$, and $500k$ iterations, processing from $17$B to $33$B tokens. Two AdamW runs with different numbers of iterations look very different because we use a cosine decay for the learning rate. We compare those baselines to AdEMAMix trained for $256k$ iterations and observe that our method reaches a loss similar to that of an AdamW model trained on nearly twice as many tokens. Analogous comparisons can be drawn from (b) and (c). Notably, in (c), a $1.3$B parameter AdEMAMix model trained on $101$B tokens performs comparably to an AdamW model trained on $197$B tokens ($95\%$ more, blue horizontal line). See § \ref{sec:llm-results} and App. \ref{app:llm_exp} for a detailed description of our experimental setup, including hyperparameters.
  • Figure 2: Comparing Adam and AdEMAMix on the Rosenbrock function. Starting from ${\bm{x}}^{(0)}=[-3,5]$, we minimize the Rosenbrock function $f(x_1,x_2)=(1-x_1)^2+100 (x_2-x_1^2)^2$, whose global minimum ($\mathbf{\star}$) is ${\bm{x}}^{\star}=[1,1]$. We use $\beta_2=0.999$ for Adam and $(\beta_1,\beta_2,\alpha)=(0.9,0.999,9)$ for AdEMAMix (see § \ref{sec:method}), and we reduce the learning rate for AdEMAMix to compensate for the influence of $\alpha$. We sweep over $\beta_1$ (resp. $\beta_3$) for Adam (resp. AdEMAMix). In (b), when Adam's $\beta_1$ is small (e.g., $0.9$), the iterates do not oscillate, but convergence is slow; increasing $\beta_1$ makes the iterates move faster but with large oscillations. In contrast, for AdEMAMix in (c), the iterates move fast and without oscillations despite $\beta_3$ being large. As a result, AdEMAMix reaches better solutions faster, as shown in (a).
  • Figure 3: Limitation of EMAs, constant $\eta$-scheduler, & Mamba results. In (a), we plot the weight $w_t$ that different EMAs assign to each past gradient ${\bm{g}}^{(t)}$ after $10k$ steps. For a given $\beta$, $\text{EMA}(\beta, {\bm{g}}^{(0)}, \ldots, {\bm{g}}^{(T)}) = \sum_{i=0}^T \beta^i (1-\beta) {\bm{g}}^{(T-i)}$, which decays the contribution of past gradients exponentially. A small $\beta$ (e.g., $0.9$) gives a high weight to the immediate past and negligible weights to older timesteps. In contrast, a high $\beta$ (e.g., $0.9999$) gives a relatively uniform, yet non-negligible, weight to all gradients. No single $\beta$ can simultaneously give a high weight to the immediate past and a non-negligible weight to very old timesteps. In (b), we train multiple $1.3$B language models using $3k$ warmup steps followed by a constant learning rate $\eta=10^{-4}$. This lets us observe the gap between AdamW and AdEMAMix without the cosine decay as a confounder. We still use schedulers for $\alpha$ and $\beta_3$, with $T_{\alpha,\beta_3}=500k$ and $\alpha=5$. Similar to [Zhai0HB22; minicpm; alexlrdecay], we decay the learning rate linearly at $t=1$M and $t=1.3$M. The loss gap between AdamW and AdEMAMix first increases and then remains constant, and AdEMAMix still outperforms AdamW after the learning rate is decayed. See App. \ref{app:misc} for the impact of the linear decay duration. In (c), we train $168$M parameter Mamba models, showing that AdEMAMix's gains generalize beyond the Transformer architecture.
  • Figure 4: Measuring forgetting using a held-out batch $B$. The top row is for AdamW; the bottom row is for AdEMAMix. We train one AdamW and one AdEMAMix model on a RedPajama dataset that does not contain batch $B$; those runs are shown in blue. We then run multiple experiments in which we inject $B$ into the training data at a specific timestep $t_B$; those runs are shown in orange. The influence $B$ has when injected at $t_B$ can be read from the evolution of the gap between the blue and orange curves. For both optimizers, the loss on $B$ decreases rapidly right after training on $B$, and this drop is sharper for AdamW than for AdEMAMix. However, with AdamW the loss on $B$ then increases faster, which we interpret as the model forgetting $B$ faster. In contrast, the AdEMAMix curves are smoother, the loss on $B$ rises more slowly, and ultimately $B$ has a bigger impact on training with AdEMAMix, as shown by the larger gap between the orange and blue curves at the last iteration. Finally, the forgetting behaviour evolves during training, with later training batches being remembered better. See App. \ref{app:forgetting} for a more detailed analysis of forgetting as training progresses.
  • Figure 5: Training time comparison & starting AdEMAMix from AdamW. In (a), we compare the time required to train $110$M and $1.3$B parameter models for $256k$ iterations. The additional EMA makes AdEMAMix slightly slower than AdamW. However, if we were to train AdamW longer to compensate for this gap, it would only gain an additional $2379$ and $5421$ iterations for the $110$M and $1.3$B parameter models, respectively. Those additional iterations would not be sufficient to close the loss gap (see Fig. \ref{fig:llm_results}). In (b) and (c), we show, for two different model sizes, the effect of switching from AdamW to AdEMAMix during training. AdEMAMix's additional EMA ${\bm{m}}_2$ is initialized to $\mathbf{0}$, and no scheduler is used for $\alpha$ or $\beta_3$. For both model sizes, the loss increases slightly at first before decreasing and outperforming the baseline. In both cases, the earlier AdEMAMix is used, the better the final loss. See App. \ref{app:misc} for results using schedulers.
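Figure 3(a) shows why a single EMA falls short. The sketch below illustrates the remedy named in the abstract: a mixture of a fast EMA and a slow EMA scaled by $\alpha$. We assume $\beta_1=0.9$ and $\alpha=5$ (values that appear in the captions above) and pick $\beta_3=0.9999$ for illustration. The sketch ignores bias correction and the second-moment denominator, so it only approximates how the numerator of the update weighs gradients of different ages late in training; the function names are hypothetical.

```python
def ema_weight(beta: float, age: int) -> float:
    """Closed-form weight a single EMA gives a gradient that is `age` steps old."""
    return beta**age * (1.0 - beta)


def mixture_weight(age: int, beta1: float = 0.9, beta3: float = 0.9999,
                   alpha: float = 5.0) -> float:
    """Weight on a gradient `age` steps old in the numerator m1 + alpha * m2.

    Bias correction and the second-moment denominator are ignored, so this only
    approximates the update late in training.
    """
    return ema_weight(beta1, age) + alpha * ema_weight(beta3, age)


if __name__ == "__main__":
    for age in (0, 10, 100, 1_000, 10_000):
        fast = ema_weight(0.9, age)
        slow = ema_weight(0.9999, age)
        mix = mixture_weight(age)
        print(f"age={age:>6}  fast={fast:.3g}  slow={slow:.3g}  mixture={mix:.3g}")
```

Under these assumptions the newest gradient gets weight $0.1 + 5\times10^{-4}$, dominated by the fast EMA, while a $10k$-step-old gradient still gets about $1.8\times10^{-4}$ from the slow EMA. No single EMA can provide both: a weight of at least $0.1$ on the newest gradient forces $\beta\le 0.9$, and then $\beta^{10000}$ is vanishingly small.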