Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective

Wu Lin, Felix Dangel, Runa Eschenhagen, Juhan Bae, Richard E. Turner, Alireza Makhzani

TL;DR

This work investigates how the behavior of adaptive gradient methods changes when the square root is removed from the preconditioner. It finds that such square-root-free adaptive methods close the generalization gap to SGD on convolutional architectures while maintaining their root-based counterparts' performance on transformers.

Abstract

Adaptive gradient optimizers like Adam(W) are the default training algorithms for many deep learning architectures, such as transformers. Their diagonal preconditioner is based on the gradient outer product, which is incorporated into the parameter update via a square root. While these methods are often motivated as approximate second-order methods, the square root represents a fundamental difference. In this work, we investigate how the behavior of adaptive methods changes when we remove the root, i.e., strengthen their second-order motivation. Surprisingly, we find that such square-root-free adaptive methods close the generalization gap to SGD on convolutional architectures, while maintaining their root-based counterparts' performance on transformers. The second-order perspective also has practical benefits for developing non-diagonal methods that can incorporate arbitrary curvature approximations through the concept of preconditioner invariance. In contrast to root-based methods like Shampoo, root-free counterparts work well and fast with half-precision since they do not require numerically unstable matrix root decompositions and inversions. Overall, our findings provide new insights into the development of adaptive methods and raise important questions regarding the overlooked role of adaptivity in their success. (Experiment code: https://github.com/yorkerlin/remove-the-square-root; optimizer code: https://github.com/f-dangel/sirfshampoo)

Paper Structure (37 sections, 75 equations, 11 figures, 2 tables)

Figures (11)

  • Figure 1: In modern (pre-)training setups (learning rate schedule, random search using 200 runs), square-root-free (RF) adaptive methods close the generalization gap between their root-based counterparts and SGD on CNNs (CIFAR-100), while maintaining their performance on vision transformers (ImageWoof10). They also work well on other problems, such as training a 3-layer LSTM and a GNN with attention \cite{zhang2022graph}. Experimental setup, performance measurements, and fine-tuning experiments on vision models are described in \ref{sec:app_exp_nn}.
  • Figure 2: Comparison of root-based versus square-root-free (RF) methods in the original (outdated) training setup in which the root was introduced (see \ref{sec:app_exp_nn} for details). Adaptive methods with the root work better than their root-free counterparts when using (1) a constant learning rate schedule, (2) the default zero initialization of the preconditioner, and (3) the default scaling for an averaged mini-batch loss.
  • Figure 3: Comparison of matrix root-free versus root-based methods on GCViT \cite{hatamizadeh2023global}, SwinViT \cite{liu2021swin}, FocalNet \cite{yang2022focal}, and VMamba \cite{liu2024vmamba}. Both matrix methods (Shampoo, IF-Shampoo) outperform diagonal methods on modern models and achieve a lower test error using modern training strategies (random search using 200 runs). In contrast to Shampoo, our inverse-free matrix method, IF-Shampoo, runs in BFP-16 and trains twice as fast, while using less memory. Using low-precision data types bridges the computation gap between diagonal and matrix methods. All models are trained for 300 epochs. We update matrix preconditioners every 2 iterations; the wall-clock time can be reduced further by updating them less frequently. See App. \ref{sec:app_exp_nn} for more details.
  • Figure 4: Diagonal adaptive methods for a (scaled) loss function $\ell_{\text{scaled}}(\boldsymbol{\mu})$ defined by averaging over $B$ data points in a mini-batch. The scalar $B$ highlighted in red is essential because we average the loss functions for mini-batch training. Without this scalar, hyperparameters are hard to tune because they implicitly depend on the batch size $B$. We initialize $s$ to $1$ in our root-free method, while $\hat{s}$ is initialized to $0$ in the original RMSProp. For simplicity, we do not include damping, weight decay, and momentum. A full-fledged version is in the Appendix, Fig. \ref{fig:rmsprop-full}.
  • Figure 5: Experiments demonstrating that square-root-free adaptive methods work well with both cross entropy (CE) and square error (SE) losses. Thus, the (diagonal) empirical Fisher estimation used in our update scheme does not suffer from the limitations of the standard empirical Fisher \cite{kunstner2019limitations}. We train all models using mini-batches and use random search (200 runs) to tune these methods. In the first two plots on the left, we consider convex problems using a constant learning rate schedule (classical training scheme) considered by \cite{kunstner2019limitations}. In the remaining plots, we consider non-convex NN problems with a step decay learning rate schedule (modern training scheme) considered by \cite{wilson2017marginal}. We consider ResNet50 models and train them for 120 epochs with mini-batch size 128. Due to the large number of classes in the CIFAR-100 dataset, we use the SE loss function suggested by \cite{hui2020evaluation} for classification tasks.
  • ...and 6 more figures
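
The contrast drawn in the Figure 4 caption (root-based RMSProp versus its square-root-free variant) can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the authors' implementation: the figure itself is not reproduced in this listing, so the exponential-moving-average form of the preconditioner update, the function names, and the parameters lr, beta2, and eps are illustrative choices. Only the batch-size factor B, the initialization of s to 1 (root-free) and of s_hat to 0 (root-based), and the omission of momentum and weight decay follow the caption. The small eps term is added here purely to avoid division by zero.

```python
import math


def rmsprop_step(mu, g, s_hat, lr=0.1, beta2=0.5, eps=1e-8):
    """One root-based RMSProp step on lists of floats (assumed EMA form)."""
    # Running average of squared mini-batch gradients; s_hat starts at 0.
    s_hat = [(1.0 - beta2) * s + beta2 * gi * gi for s, gi in zip(s_hat, g)]
    # The preconditioner enters the update through a square root.
    mu = [m - lr * gi / (math.sqrt(s) + eps) for m, gi, s in zip(mu, g, s_hat)]
    return mu, s_hat


def rf_rmsprop_step(mu, g, s, batch_size, lr=0.1, beta2=0.5, eps=1e-8):
    """One square-root-free step (assumed EMA form); s starts at 1."""
    # The squared gradient of the averaged loss is rescaled by the batch size B.
    s = [(1.0 - beta2) * si + beta2 * batch_size * gi * gi for si, gi in zip(s, g)]
    # No square root: the update divides by the preconditioner directly.
    mu = [m - lr * gi / (si + eps) for m, gi, si in zip(mu, g, s)]
    return mu, s


if __name__ == "__main__":
    mu0, grad = [1.0], [2.0]  # hypothetical parameter and mini-batch gradient
    print(rmsprop_step(mu0, grad, s_hat=[0.0]))
    print(rf_rmsprop_step(mu0, grad, s=[1.0], batch_size=4))
```

In this sketch the two steps differ only in the factor batch_size inside the preconditioner update and in whether the square root is applied before dividing, which is why the initialization of the preconditioner matters more in the root-free case.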

Theorems & Definitions (22)

  • Definition 1
  • Definition 2
  • Definition 3
  • Definition 4
  • Claim 1
  • Claim 2
  • Claim 3
  • Claim 4
  • proof
  • proof
  • ...and 12 more