
Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures

Runa Eschenhagen, Alexander Immer, Richard E. Turner, Frank Schneider, Philipp Hennig

TL;DR

The paper addresses the challenge of applying second-order optimization via K-FAC to modern neural networks that use linear weight-sharing. It develops a framework that distinguishes two settings, expand and reduce, each with a corresponding K-FAC variant. It proves that each variant is exact for deep linear networks in its respective setting, and notes that K-FAC-reduce is generally faster to compute. Empirically, K-FAC-reduce speeds up marginal-likelihood-based hyperparameter selection for a Wide ResNet on CIFAR-10. When training a graph neural network on ogbg-molpcba and a vision transformer on ImageNet, both variants reach a fixed validation target in 50-75% of the steps of a first-order reference run. The work indicates that K-FAC can be extended to contemporary architectures and used to accelerate both training and automatic hyperparameter selection in practice.
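To make the two variants concrete, consider a linear weight-sharing layer whose inputs are $a_{n,r} \in \mathbb{R}^D$, with $N$ examples and $R$ shared positions (e.g. tokens or graph nodes). K-FAC-expand treats every $(n, r)$ pair like an independent example. K-FAC-reduce first aggregates over the weight-sharing dimension. The following minimal NumPy sketch computes only the input-side factor $A$. It assumes the normalisations $1/(NR)$ and $1/(NR^2)$ as our reading of the two variants, and the function names are hypothetical.

```python
import numpy as np

def kfac_expand_A(acts):
    """Input factor in the K-FAC-expand style (assumed normalisation 1/(N*R)).

    acts has shape (N, R, D): N examples, R weight-sharing positions, D features.
    Every (n, r) pair contributes its own outer product: cost O(N * R * D^2).
    """
    N, R, D = acts.shape
    flat = acts.reshape(N * R, D)
    return flat.T @ flat / (N * R)

def kfac_reduce_A(acts):
    """Input factor in the K-FAC-reduce style (assumed normalisation 1/(N*R^2)).

    Activations are summed over the weight-sharing dimension first, so only
    N outer products are formed: cost O(N * R * D + N * D^2).
    """
    N, R, D = acts.shape
    summed = acts.sum(axis=1)               # shape (N, D)
    return summed.T @ summed / (N * R**2)

rng = np.random.default_rng(0)
acts = rng.standard_normal((4, 16, 8))      # N=4, R=16, D=8
A_exp = kfac_expand_A(acts)
A_red = kfac_reduce_A(acts)
print(A_exp.shape, A_red.shape)             # (8, 8) (8, 8)
```

The output-side factor $B$ is assumed to be built the same way from quantities w.r.t. the layer outputs, again summing over $r$ before the outer product in the reduce case. Forming $N$ instead of $N \cdot R$ outer products is one plausible source of K-FAC-reduce's speed advantage. With $R = 1$, the two variants coincide.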

Abstract

The core components of many modern neural network architectures, such as transformers, convolutional, or graph neural networks, can be expressed as linear layers with weight-sharing. Kronecker-Factored Approximate Curvature (K-FAC), a second-order optimisation method, has shown promise to speed up neural network training and thereby reduce computational costs. However, there is currently no framework to apply it to generic architectures, specifically ones with linear weight-sharing layers. In this work, we identify two different settings of linear weight-sharing layers which motivate two flavours of K-FAC: expand and reduce. We show that they are exact for deep linear networks with weight-sharing in their respective setting. Notably, K-FAC-reduce is generally faster than K-FAC-expand, which we leverage to speed up automatic hyperparameter selection via optimising the marginal likelihood for a Wide ResNet. Finally, we observe little difference between these two K-FAC variations when using them to train both a graph neural network and a vision transformer. However, both variations are able to reach a fixed validation metric target in 50-75% of the number of steps of a first-order reference run, which translates into a comparable improvement in wall-clock time. This highlights the potential of applying K-FAC to modern neural network architectures.
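Kronecker-factored curvature is what makes marginal-likelihood-based hyperparameter selection tractable. We assume the marginal likelihood is approximated with a Laplace approximation, which this summary does not spell out. Under that assumption, a layer's posterior precision is approximated by $A \otimes B + \delta I$, where $\delta$ is a prior precision. Its log-determinant then follows from the eigenvalues of the two small factors: $\log\det(A \otimes B + \delta I) = \sum_{i,j} \log(\lambda_i \mu_j + \delta)$. The following minimal NumPy sketch uses a hypothetical helper `kron_logdet` and checks it against the dense computation.

```python
import numpy as np

def kron_logdet(A, B, delta):
    """log det(A ⊗ B + delta * I) from the eigenvalues of the factors.

    A (d×d) and B (h×h) are assumed symmetric positive semi-definite.
    Cost O(d^3 + h^3) instead of O((d*h)^3) for the dense matrix.
    """
    eig_a = np.linalg.eigvalsh(A)
    eig_b = np.linalg.eigvalsh(B)
    # The eigenvalues of A ⊗ B are all products eig_a[i] * eig_b[j];
    # adding delta * I shifts each of them by delta.
    return np.sum(np.log(np.outer(eig_a, eig_b) + delta))

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 3)); A = X.T @ X    # 3×3 PSD factor
Y = rng.standard_normal((8, 4)); B = Y.T @ Y    # 4×4 PSD factor
delta = 0.5

dense = np.kron(A, B) + delta * np.eye(12)
sign, logdet_dense = np.linalg.slogdet(dense)
print(np.isclose(kron_logdet(A, B, delta), logdet_dense))  # True
```

The eigendecomposition step is the same for both K-FAC variants because the factors have the same sizes. Only the cost of forming the factors differs.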

Paper Structure

This paper contains 32 sections, 8 theorems, 58 equations, 5 figures, and 6 tables.

Key Result

Proposition 1

For layer $\ell$ of a deep linear network (as defined in the paper's deep linear network equation, eq:deep_linear) and a Gaussian likelihood with positive definite covariance matrix $\boldsymbol{\Sigma} \in \mathbb{R}^{C \times C}$, K-FAC-expand is exact in the expand setting.
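Proposition 1 can be checked numerically on a small instance. This sketch assumes a two-layer deep linear network $f(x) = W_2 W_1 x$ without biases, applied to each of the $R$ positions of every input, with a loss over all $N \cdot R$ outputs (the expand setting). It also assumes a Gaussian likelihood, so the loss Hessian w.r.t. the output is $\Lambda = \boldsymbol{\Sigma}^{-1}$, and column-stacking vectorisation of $W_1$. We take "exact" to mean that the K-FAC-expand approximation $A \otimes B$ equals the layer's generalised Gauss-Newton (GGN) matrix. The sizes and names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, R, D, H, C = 3, 5, 4, 3, 2             # examples, shared positions, in, hidden, out
X = rng.standard_normal((N, R, D))        # inputs a_{n,r} to layer 1
W2 = rng.standard_normal((C, H))          # later layer (W1's value does not enter: f is linear in W1)
M = rng.standard_normal((C, C))
Sigma = M @ M.T + C * np.eye(C)           # p.d. covariance of the Gaussian likelihood
Lam = np.linalg.inv(Sigma)                # Hessian of the Gaussian NLL w.r.t. the output

# Exact GGN for vec(W1): sum over all (n, r) of J^T Λ J with J = a^T ⊗ W2.
acts = X.reshape(N * R, D)
ggn = np.zeros((D * H, D * H))
for a in acts:
    J = np.kron(a[None, :], W2)           # shape (C, D*H)
    ggn += J.T @ Lam @ J

# K-FAC-expand: each (n, r) pair is treated as its own example.
A = acts.T @ acts / (N * R)
# The Jacobian of the output w.r.t. the layer output s = W1 a is W2 for every (n, r).
B = (N * R) * (W2.T @ Lam @ W2)
kfac_expand = np.kron(A, B)

print(np.allclose(ggn, kfac_expand))      # True
```

The check succeeds because, in a deep linear network, the Jacobian from the layer output to the network output does not depend on $n$ or $r$. The GGN therefore factorises exactly into an input term and an output term.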

Figures (5)

  • Figure 1: Visualisation of K-FAC-expand and K-FAC-reduce in the expand setting. The quantities shown are for a single layer within a deep linear network, with $N = 4$, $R = 2$, $P_{\ell, \mathrm{in}} = 8$, $P_{\ell, \mathrm{out}} = 8$, and $P_\ell = P_{\ell, \mathrm{in}} \cdot P_{\ell, \mathrm{out}} = 64$. As shown in the subsection on the expand setting, K-FAC-expand is exact in this setting and K-FAC-reduce is not. For better visibility, the color scale is not shared across quantities; in the approximation error (right), black represents zero.
  • Figure 2: Visualisation of K-FAC-expand and K-FAC-reduce in the reduce setting. This is analogous to Figure 1, but for the reduce setting, where K-FAC-reduce is exact and K-FAC-expand is not (see the subsection on the reduce setting).
  • Figure 3: Training results for a graph neural network on ogbg-molpcba. Both K-FAC variants need $\approx 50\%$ of the steps of the first-order reference run to reach the validation target, which translates almost directly into reduced wall-clock time. The K-FAC statistics are updated every $10$ steps. We show runs with five different random seeds for each method.
  • Figure 4: Training results for a vision transformer on ImageNet. The K-FAC statistics are updated every $50$ steps. Because this amortises the cost of the K-FAC updates, the reduced number of steps to the target translates into reduced wall-clock time. We show runs with three different random seeds for each method.
  • Figure 5: Training results for a vision transformer on ImageNet. This is analogous to Figure 4, but the K-FAC statistics are updated every step and different hyperparameters are used. Due to K-FAC's overhead, the wall-clock time is not reduced in this setting. This setting also makes the speed difference between K-FAC-expand and K-FAC-reduce apparent.
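Figures 3 to 5 depend on how often the K-FAC statistics are refreshed. The following minimal NumPy sketch shows that amortisation. It is not the paper's training setup and assumes:

- a single linear layer trained on mini-batches with a squared loss and unit covariance, so the output-side factor is the identity;
- a simple damping added to both factors;
- hypothetical names (`kfac_train_linear`, `update_every`).

The factors and their inverses are recomputed only every `update_every` steps and reused in between.

```python
import numpy as np

def kfac_train_linear(X, Y, steps=50, update_every=10, batch=128,
                      lr=0.5, damping=1e-3, seed=0):
    """Train y ≈ W x (squared loss) with K-FAC-preconditioned steps.

    The Kronecker factors and their damped inverses are recomputed only
    every `update_every` steps and reused in between (amortisation).
    Returns the trained weights and the number of factor updates.
    """
    rng = np.random.default_rng(seed)
    N, D = X.shape
    C = Y.shape[1]
    W = np.zeros((C, D))
    A_inv = B_inv = None
    n_updates = 0
    for step in range(steps):
        idx = rng.choice(N, size=batch, replace=False)
        xb, yb = X[idx], Y[idx]
        G = (xb @ W.T - yb).T @ xb / batch      # gradient of 0.5 * mean sq. error, (C, D)
        if step % update_every == 0:
            A = xb.T @ xb / batch               # input factor, (D, D)
            B = np.eye(C)                       # output factor for unit-covariance squared loss
            A_inv = np.linalg.inv(A + damping * np.eye(D))
            B_inv = np.linalg.inv(B + damping * np.eye(C))
            n_updates += 1
        # (A ⊗ B)^{-1} vec(G) = vec(B^{-1} G A^{-1}) for symmetric A.
        W -= lr * B_inv @ G @ A_inv
    return W, n_updates

def mse(W, X, Y):
    return 0.5 * np.mean(np.sum((X @ W.T - Y) ** 2, axis=1))

rng = np.random.default_rng(1)
X = rng.standard_normal((1024, 4))
W_true = rng.standard_normal((2, 4))
Y = X @ W_true.T                                # noise-free linear targets
W, n_updates = kfac_train_linear(X, Y)
print(n_updates)                                # 5
print(mse(W, X, Y) < 1e-2 * mse(np.zeros_like(W), X, Y))  # True
```

Raising `update_every` lowers the per-step overhead but uses staler curvature. Figures 4 and 5 show this trade-off at scale: updating every $50$ steps reduces wall-clock time, while updating every step does not.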

Theorems & Definitions (15)

  • Proposition 1: Exactness of K-FAC-expand for deep linear network in the expand setting
  • Proposition 2: Exactness of K-FAC-reduce for deep linear network in the reduce setting
  • Lemma 3: Sufficient condition for exactness of K-FAC-expand in the expand setting
  • proof
  • Proposition 4: Exactness of K-FAC-expand for single linear layer in the expand setting
  • proof
  • Proposition 4: Exactness of K-FAC-expand for deep linear network in the expand setting
  • proof
  • Lemma 5: Sufficient condition for exactness of K-FAC-reduce in the reduce setting
  • proof
  • ...and 5 more
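Proposition 2 can be checked in the same way on a two-layer deep linear network in the reduce setting. This sketch makes the following assumptions, all of them illustrative, as are the sizes and names:

- The weight-sharing dimension is averaged out before the last layer, $f(X_n) = W_2 \cdot \frac{1}{R}\sum_r W_1 x_{n,r}$.
- The likelihood is Gaussian with covariance $\boldsymbol{\Sigma}$.
- The K-FAC-reduce factors are $A = \frac{1}{NR^2}\sum_n (\sum_r a_{n,r})(\sum_r a_{n,r})^\top$ and $B = \sum_n (\sum_r J_{n,r})^\top \Lambda (\sum_r J_{n,r})$, where $J_{n,r}$ is the Jacobian of $f(X_n)$ w.r.t. the layer output $s_{n,r}$.

For contrast, the sketch also applies expand-style factors (per-position outer products) in this setting.

```python
import numpy as np

rng = np.random.default_rng(0)
N, R, D, H, C = 3, 5, 4, 3, 2
X = rng.standard_normal((N, R, D))        # inputs a_{n,r} to layer 1
W2 = rng.standard_normal((C, H))
M = rng.standard_normal((C, C))
Lam = np.linalg.inv(M @ M.T + C * np.eye(C))   # Λ = Σ^{-1} with Σ p.d.

# Reduce setting: one output per example, f(X_n) = W2 @ mean_r(W1 @ x_{n,r}).
# Exact GGN for vec(W1): J_n = mean_r(a_{n,r})^T ⊗ W2.
ggn = np.zeros((D * H, D * H))
for n in range(N):
    J = np.kron(X[n].mean(axis=0)[None, :], W2)   # shape (C, D*H)
    ggn += J.T @ Lam @ J

# K-FAC-reduce: aggregate over R before forming outer products.
summed = X.sum(axis=1)                    # shape (N, D)
A_red = summed.T @ summed / (N * R**2)
J_s = W2 / R                              # d f(X_n) / d s_{n,r}, identical for every r
B_red = N * (R * J_s).T @ Lam @ (R * J_s) # sum_n (sum_r J_s)^T Λ (sum_r J_s)
kfac_reduce = np.kron(A_red, B_red)

# Expand-style factors: one outer product per (n, r) pair.
acts = X.reshape(N * R, D)
A_exp = acts.T @ acts / (N * R)
B_exp = (N * R) * (J_s.T @ Lam @ J_s)
kfac_expand = np.kron(A_exp, B_exp)

print(np.allclose(ggn, kfac_reduce))      # True
print(np.allclose(ggn, kfac_expand))      # False
```

The expand-style approximation fails here because it replaces $(\sum_r a_{n,r})(\sum_r a_{n,r})^\top$ with $\sum_r a_{n,r} a_{n,r}^\top$ and so drops the cross-position terms. This mirrors the pattern shown in Figure 2.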