Quadratic models for understanding catapult dynamics of neural networks

Libin Zhu, Chaoyue Liu, Adityanarayanan Radhakrishnan, Mikhail Belkin

TL;DR

This paper introduces Neural Quadratic Models (NQMs) as a second-order Taylor approximation of neural networks to study optimization and generalization beyond the infinite-width linear regime. It proves that NQMs can exhibit catapult dynamics under large learning rates, derives detailed single- and multi-example dynamics, and connects these behaviors to broader general quadratic models and wide networks. Empirically, NQMs mirror neural networks in generalization improvements observed in the catapult regime, across various architectures and datasets, and outperform the linear NTK benchmark in this regime. The work suggests that quadratic models provide a tractable, informative lens for understanding finite-width neural networks and motivates future exploration of their induced kernels and potential for representation learning.
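For orientation, the second-order Taylor approximation mentioned above can be written out as follows. This is a sketch in assumed notation: ${\mathbf{w}}_0$ denotes the weights at initialization and $H_f$ the Hessian of the network output $f$ with respect to the weights; the names $f_{\mathrm{lin}}$ and $f_{\mathrm{quad}}$ follow the figure captions, and the paper's own equations may differ in details such as the input argument.

```latex
\begin{align*}
f_{\mathrm{lin}}(\mathbf{w})  &= f(\mathbf{w}_0) + \nabla f(\mathbf{w}_0)^{\top} (\mathbf{w} - \mathbf{w}_0), \\
f_{\mathrm{quad}}(\mathbf{w}) &= f_{\mathrm{lin}}(\mathbf{w})
  + \tfrac{1}{2}\, (\mathbf{w} - \mathbf{w}_0)^{\top} H_f(\mathbf{w}_0)\, (\mathbf{w} - \mathbf{w}_0).
\end{align*}
```

The linear model keeps only the first-order term, so its tangent kernel is fixed at initialization; the quadratic term is what allows the kernel to change during training.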

Abstract

While neural networks can be approximated by linear models as their width increases, certain properties of wide neural networks cannot be captured by linear models. In this work we show that recently proposed Neural Quadratic Models can exhibit the "catapult phase" [Lewkowycz et al. 2020] that arises when training such models with large learning rates. We then empirically show that the behaviour of neural quadratic models parallels that of neural networks in generalization, especially in the catapult phase regime. Our analysis further demonstrates that quadratic models can be an effective tool for analysis of neural networks.

Paper Structure (64 sections, 17 theorems, 126 equations, 17 figures)

Key Result

Theorem 1

Consider training the NQM Eq. (\ref{eq:nn_quad_relu}) with squared loss on a single training example using gradient descent (GD). With a super-critical learning rate $\eta \in \left[\frac{2+ \epsilon}{\lambda_0}, \frac{4-\epsilon}{\lambda_0} \right]$ where $\epsilon = \Theta\left(\frac{\log m}{\sqrt{m}}\right)$, the catapult

Figures (17)

  • Figure 1: Optimization dynamics for linear and non-linear models based on the choice of learning rate. (a) Linear models converge monotonically if the learning rate is less than ${\eta_{\mathrm{crit}}}$ and diverge otherwise. (b) Unlike linear models, finitely wide neural networks and NQMs Eq. (\ref{eq:nn_quad}) (or general quadratic models Eq. (\ref{eq:quadratic})) can additionally exhibit a catapult phase when ${\eta_{\mathrm{crit}}} < \eta <{\eta_{\mathrm{max}}}$.
  • Figure 2: (a) Optimization dynamics of wide neural networks with sub-critical and super-critical learning rates. With sub-critical learning rates ($0<\eta<{\eta_{\mathrm{crit}}}$), the tangent kernel of wide neural networks is nearly constant during training, and the loss decreases monotonically. The whole optimization path is contained in the ball $B({\mathbf{w}}_0,R):=\{{\mathbf{w}}: \|{\mathbf{w}} - {\mathbf{w}}_0\|\leq R\}$ with a finite radius $R$. With super-critical learning rates (${\eta_{\mathrm{crit}}}<\eta<{\eta_{\mathrm{max}}}$), the catapult phase happens: the loss first increases and then decreases, along with a decrease of the norm of the tangent kernel. The optimization path goes beyond the finite-radius ball. (b) Test loss of $f_{{\mathrm{quad}}}$, $f$ and $f_{{{\mathrm{lin}}}}$ plotted against different learning rates. With sub-critical learning rates, all three models have nearly identical test loss. With super-critical learning rates, $f$ and $f_{{\mathrm{quad}}}$ achieve a smaller best test loss than they do with sub-critical learning rates. Experimental details are in Appendix \ref{subsec:exp_quad_setting}.
  • Figure 3: Training dynamics of NQMs in the multiple-example case with different learning rates. By our analysis, the two critical values are $2/\lambda_1(0) = 0.37$ and $2/\lambda_2(0) = 0.39$. When $\eta<0.37$, linear dynamics dominate, hence the kernel is nearly constant; when $0.37<\eta<0.39$, the catapult phase happens in ${\boldsymbol{p}}_1$ and only $\lambda_1(t)$ decreases; when $0.39<\eta<{\eta_{\mathrm{max}}}$, the catapult phase happens in ${\boldsymbol{p}}_1$ and ${\boldsymbol{p}}_2$, hence both $\lambda_1(t)$ and $\lambda_2(t)$ decrease. Experimental details can be found in Appendix \ref{subsec:multi_quad}.
  • Figure 4: Best test loss plotted against different learning rates for $f({\mathbf{w}})$, $f_{{{\mathrm{lin}}}}({\mathbf{w}})$ and $f_{{\mathrm{quad}}}({\mathbf{w}})$ across a variety of datasets and network architectures.
  • Figure 5: Training dynamics of wide neural networks in the multiple-example case with different learning rates. Compared to the training dynamics of NQMs, i.e., Figure \ref{fig:multi_quad}, the behaviour of the top eigenvalues is almost the same across learning rates: when $\eta<0.37$, the kernel is nearly constant; when $0.37<\eta<0.39$, only $\lambda_1(t)$ decreases; when $0.39<\eta<{\eta_{\mathrm{max}}}$, both $\lambda_1(t)$ and $\lambda_2(t)$ decrease. See the experimental setting in Appendix \ref{subsec:multi_nn}.
  • ...and 12 more figures
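
The three regimes described in the captions of Figures 1 and 2 can be reproduced on a toy problem. The sketch below is not the paper's experimental setup: it assumes a two-layer linear network $f(u, v) = \langle u, v \rangle / \sqrt{m}$ (the toy model analysed by Lewkowycz et al. 2020), which is exactly quadratic in its weights, trained on a single example with input $1$ and target $0$. The function names, the deterministic initialization, and the width $m = 400$ are illustrative choices made so that $\lambda_0 = 2$ and $\eta_{\mathrm{crit}} = 2/\lambda_0 = 1$.

```python
import math

def train_two_layer_linear(eta, steps, m=400, k=10):
    """GD on f(u, v) = <u, v> / sqrt(m) with loss f^2 / 2 (input 1, target 0)."""
    # Hypothetical deterministic init: kernel lambda_0 = 2, output f_0 = 2k / sqrt(m).
    u = [1.0] * m
    v = [1.0] * (m // 2 + k) + [-1.0] * (m // 2 - k)
    scale = 1.0 / math.sqrt(m)
    losses, kernels = [], []
    for _ in range(steps + 1):
        f = scale * sum(a * b for a, b in zip(u, v))
        # Tangent kernel for one example: squared norm of the gradient of f.
        lam = (sum(a * a for a in u) + sum(b * b for b in v)) / m
        losses.append(0.5 * f * f)
        kernels.append(lam)
        g = eta * f * scale
        # Simultaneous update: both right-hand sides use the old u and v.
        u, v = ([a - g * b for a, b in zip(u, v)],
                [b - g * a for a, b in zip(u, v)])
    return losses, kernels

def train_linearized(eta, steps, lam0=2.0, f0=1.0):
    """Same loss, but with the tangent kernel frozen at its initial value."""
    losses, f = [], f0
    for _ in range(steps + 1):
        losses.append(0.5 * f * f)
        f *= 1.0 - eta * lam0
    return losses

eta_crit = 2.0 / 2.0  # 2 / lambda_0 for the init above

sub_losses, sub_kernels = train_two_layer_linear(0.8 * eta_crit, steps=30)
cat_losses, cat_kernels = train_two_layer_linear(1.5 * eta_crit, steps=100)
lin_losses = train_linearized(1.5 * eta_crit, steps=100)

print("sub-critical:   final loss", sub_losses[-1], "final kernel", sub_kernels[-1])
print("super-critical: peak loss", max(cat_losses), "final loss", cat_losses[-1],
      "final kernel", cat_kernels[-1])
print("linearized, super-critical: final loss", lin_losses[-1])
```

With the sub-critical rate the loss should fall monotonically while the kernel barely moves. With the super-critical rate the quadratic model's loss should first spike and then converge, with the kernel ending below $2/\eta$, whereas the linearized model simply diverges.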

Theorems & Definitions (33)

  • Definition 1: Tangent Kernel
  • Definition 2: Critical learning rate
  • Theorem 1: Catapult dynamics on a single training example
  • proof: Proof of Theorem \ref{thm:single}
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Theorem 2
  • Theorem 3: Catapult dynamics on multiple training examples
  • ...and 23 more