Spectral-factorized Positive-definite Curvature Learning for NN Training

Wu Lin, Felix Dangel, Runa Eschenhagen, Juhan Bae, Richard E. Turner, Roger B. Grosse

TL;DR

This work tackles the challenge of flexible, stable SPD curvature learning for neural network training, where existing non-diagonal methods suffer from costly matrix-root computations. It introduces a spectral-factorized parameterization of the preconditioner $\boxed{\mathbf{S}=\mathbf{B}\mathrm{Diag}(\mathbf{d})\mathbf{B}^T}$ and develops a Riemannian gradient-descent framework in local coordinates to enable efficient, root-free updates for arbitrary matrix roots. The approach extends to Kronecker-structured preconditioners with determinant constraints, yielding scalable updates without explicit matrix decompositions. Empirically, the method demonstrates competitive performance in SPD matrix optimization, gradient-free settings, and low-precision neural network training, while maintaining reparametrization invariance and numerical stability. This spectral-factorized, geometry-aware framework broadens the applicability of non-diagonal curvature learning to a wider range of curvature information and practical training regimes.
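
To make the benefit of the spectral factorization concrete, here is a minimal NumPy sketch, not the authors' implementation, of why keeping $\mathbf{B}$ and $\mathbf{d}$ makes any matrix root cheap: $\mathbf{S}^{-1/p}\mathbf{g} = \mathbf{B}\,\mathrm{Diag}(\mathbf{d}^{-1/p})\,\mathbf{B}^T\mathbf{g}$. The function name `apply_inverse_root`, the root order `p`, and the random test data are assumptions made for illustration; the dense eigendecomposition appears only as a reference check.

```python
import numpy as np


def apply_inverse_root(B, d, g, p=2.0):
    """Return S^{-1/p} g for S = B Diag(d) B^T, using only the factors."""
    # Rotate into the eigenbasis, rescale each coordinate, rotate back.
    return B @ ((B.T @ g) / d ** (1.0 / p))


rng = np.random.default_rng(0)
n = 6
B, _ = np.linalg.qr(rng.standard_normal((n, n)))  # random orthogonal basis
d = rng.uniform(0.5, 2.0, size=n)                  # positive eigenvalues
g = rng.standard_normal(n)

# Dense reference, built only to check the factored computation.
S = B @ np.diag(d) @ B.T
w, V = np.linalg.eigh(S)

for p in (1.0, 2.0, 4.0):
    reference = V @ ((V.T @ g) / w ** (1.0 / p))
    assert np.allclose(apply_inverse_root(B, d, g, p), reference)
```

Changing the root only changes the exponent applied to `d`, so no new decomposition is needed when switching between, say, an inverse square root and an inverse fourth root.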

Abstract

Many training methods, such as Adam(W) and Shampoo, learn a positive-definite curvature matrix and apply an inverse root before preconditioning. Recently, non-diagonal training methods, such as Shampoo, have gained significant attention; however, they remain computationally inefficient and are limited to specific types of curvature information due to the costly matrix root computation via matrix decomposition. To address this, we propose a Riemannian optimization approach that dynamically adapts spectral-factorized positive-definite curvature estimates, enabling the efficient application of arbitrary matrix roots and generic curvature learning. We demonstrate the efficacy and versatility of our approach in positive-definite matrix optimization and covariance adaptation for gradient-free optimization, as well as its efficiency in curvature learning for neural net training.
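
For Kronecker-structured preconditioners of the form $\mathbf{S}^{(C)} \otimes \mathbf{S}^{(K)}$, the same idea applies factor by factor. The sketch below is an illustration under stated assumptions rather than the paper's algorithm: it assumes each factor is stored as $\mathbf{B}^{(l)}\mathrm{Diag}(\mathbf{d}^{(l)})(\mathbf{B}^{(l)})^T$, that the gradient $\mathbf{G}$ is an $n \times m$ matrix, and that vectorization is row-major. The name `kron_inverse_root` is hypothetical.

```python
import numpy as np


def kron_inverse_root(B_C, d_C, B_K, d_K, G, p=2.0):
    """Apply (S_C kron S_K)^{-1/p} to G (n x m) without forming the Kronecker product."""
    # The eigenvalues of S_C kron S_K are all products d_C[i] * d_K[j].
    scale = np.outer(d_C, d_K) ** (1.0 / p)
    return B_C @ ((B_C.T @ G @ B_K) / scale) @ B_K.T


rng = np.random.default_rng(0)
n, m = 4, 3
B_C, _ = np.linalg.qr(rng.standard_normal((n, n)))
B_K, _ = np.linalg.qr(rng.standard_normal((m, m)))
d_C = rng.uniform(0.5, 2.0, size=n)
d_K = rng.uniform(0.5, 2.0, size=m)
G = rng.standard_normal((n, m))

# Dense reference on the full (n*m) x (n*m) matrix, used only as a check.
S_C = B_C @ np.diag(d_C) @ B_C.T
S_K = B_K @ np.diag(d_K) @ B_K.T
w, V = np.linalg.eigh(np.kron(S_C, S_K))
dense = (V @ ((V.T @ G.ravel()) / w ** 0.5)).reshape(n, m)

assert np.allclose(kron_inverse_root(B_C, d_C, B_K, d_K, G, p=2.0), dense)
```

The factored version only multiplies by $n \times n$ and $m \times m$ matrices, whereas the dense reference decomposes an $nm \times nm$ matrix.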

Paper Structure

This paper contains 35 sections, 48 equations, 12 figures.

Figures (12)

  • Figure 1: Adaptive update schemes for full-matrix and Kronecker-based spectral factorizations. Full-matrix scheme: $\mathrm{Tril}(\mathbf{U})$ is a lower-triangular matrix with the $(i,j)$-th entry $[\mathbf{U}]_{ij} := -[\mathbf{B}^T \mathbf{g}\mathbf{g}^T \mathbf{B}]_{ij}/(d_i - d_j)$ when $d_i \neq d_j$ and $0$ otherwise. Kronecker-based scheme: We assume that NN weights take a matrix form: $\mathbf{M} := \mathrm{Mat}(\boldsymbol{\mu}) \in \mathbb{R}^{n \times m}$, where $\mathrm{Mat}(\cdot)$ is a matrix representation of the vector $\boldsymbol{\mu}$. We define $\mathbf{W}^{(K)} := (\mathbf{B}^{(K)})^T \mathbf{G}^T (\mathbf{S}^{(C)})^{-1} \mathbf{G} \mathbf{B}^{(K)}$, $\mathbf{W}^{(C)} := (\mathbf{B}^{(C)})^T \mathbf{G} (\mathbf{S}^{(K)})^{-1} \mathbf{G}^T \mathbf{B}^{(C)}$, $k^{(K)} := n$, and $k^{(C)} := m$, where $\mathbf{S}^{(l)} := \mathbf{B}^{(l)} \mathrm{Diag}(\mathbf{d}^{(l)}) (\mathbf{B}^{(l)})^T$ for $l \in \{C,K\}$. We define a lower-triangular matrix $\mathrm{Tril}(\mathbf{U}^{(l)})$ with its $(i,j)$-th entry $[\mathbf{U}^{(l)}]_{ij} := -[\mathbf{W}^{(l)}]_{ij}/(d^{(l)}_i - d^{(l)}_j)$ if $d^{(l)}_i \neq d^{(l)}_j$ and $0$ otherwise, where $d^{(l)}_i$ denotes the $i$-th entry of the vector $\mathbf{d}^{(l)}$. For numerical stability, we set $[\mathbf{U}^{(l)}]_{ij} = 0$ if $|d^{(l)}_i - d^{(l)}_j|$ is near $0$. See Fig. \ref{fig:kron_updates} in Appx. \ref{app:connections} for a simplified version.
  • Figure 2: Empirical validation of our update schemes for SPD curvature learning. Full-matrix scheme: The first plot on the left shows that, when updating $\mathbf{S} \in \mathbb{R}^{100 \times 100}$, our scheme converges to a fixed-point solution as fast as the default scheme in Eq. \ref{eq:root_free} (with $\gamma=1$) and the Cholesky-based scheme. The second plot illustrates how closely our scheme matches the iterates generated by the default update scheme at each iteration. Kronecker-based scheme: The third plot shows that our update scheme gives a structural approximation $\mathbf{S}^{(C)} \otimes \mathbf{S}^{(K)}$ of a fixed-point solution obtained by the default full-matrix update scheme for $\mathbf{S} \in \mathbb{R}^{99 \times 99}$, where $\mathbf{S}^{(C)} \in \mathbb{R}^{9 \times 9}$ and $\mathbf{S}^{(K)} \in \mathbb{R}^{11 \times 11}$. Our scheme converges as fast as Kronecker-structured baseline methods, including the impractical projection-based method. The last plot illustrates how closely our scheme matches the unstructured iterates generated by the default one at each iteration. See Figs. \ref{fig:full_mat_toy}-\ref{fig:kron_toy} for more results.
  • Figure 3: Experiments showcase the efficacy and versatility of our approach for generic curvature learning. Our update scheme matches the equivalent Riemannian baselines, empirically illustrating the reparametrization invariance. SPD matrix optimization: The first two plots on the left show the performance of our full-matrix update scheme for learning SPD matrices. Our update scheme matches the baselines because it is RGD in local coordinates. Gradient-free optimization: The last three plots show the performance of our scheme for gradient-free optimization problems on the Ackley (multimodal), Rosenbrock (flat valley), and Griewank (multimodal) functions. See Fig. \ref{fig:extra_results} in the appendix for more results.
  • Figure 4: Experiments demonstrate the efficiency of our update schemes for low-precision NN training. The plots show the performance of our Kronecker-based scheme for training vision transformers in half precision. All models are trained for 210 epochs, including 10 warmup epochs. For SOAP and our method, we update the preconditioners every two iterations. SOAP runs much slower than the other methods because it has to run in single precision to use matrix decomposition. Using a different matrix root can affect the performance. Our method not only matches Muon's performance but also opens the door to using curvature information and matrix roots beyond Muon. See Figs. \ref{fig:imagewoof} and \ref{fig:imagenet25} in Appx. \ref{app:extra_nn_training} for a comparison of the methods based on iteration efficiency and wall-clock time.
  • Figure 5: Comparison between the original full-matrix update scheme (e.g., full-matrix RMSprop when $\gamma=1$) and our update scheme when the exponential map is truncated. We can see that our update scheme is a decomposition-free version of the adaptive method. Here, we can use the first-order truncation of the exponential map (see the top box of Fig. \ref{fig:kronecker} for the map). This is possible because we leverage the positive semi-definiteness of the GOP. Thanks to the positive semi-definiteness, the update of $\mathbf{d}$ is always non-negative. In practice, we introduce a damping term $\lambda$ so that $\mathbf{d}$ is always positive.
  • ...and 7 more figures
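
The $\mathrm{Tril}(\mathbf{U})$ definition in the Figure 1 caption can be sketched in a few lines of NumPy. This is a hedged illustration of that one formula only, not the full update scheme: the function name `tril_u` and the threshold `tol` (standing in for "near $0$") are assumptions, and the test data is random.

```python
import numpy as np


def tril_u(W, d, tol=1e-8):
    """Tril(U) with [U]_ij = -[W]_ij / (d_i - d_j), zeroed where d_i and d_j nearly coincide."""
    diff = d[:, None] - d[None, :]           # diff[i, j] = d_i - d_j
    mask = np.abs(diff) > tol                 # skip (near-)equal eigenvalue pairs
    U = np.zeros_like(W, dtype=float)
    U[mask] = -W[mask] / diff[mask]
    return np.tril(U)                         # keep the lower-triangular part


rng = np.random.default_rng(0)
n = 4
B, _ = np.linalg.qr(rng.standard_normal((n, n)))  # orthogonal eigenbasis
g = rng.standard_normal(n)
d = np.array([2.0, 1.0, 1.0, 0.5])            # d[1] == d[2] exercises the guard

v = B.T @ g
W = np.outer(v, v)                            # equals B^T g g^T B (full-matrix scheme)
L = tril_u(W, d)
```

The same function would apply to the Kronecker-based scheme by passing $\mathbf{W}^{(l)}$ and $\mathbf{d}^{(l)}$ instead. Note also that the diagonal of `W` is `v ** 2`, which is never negative; this is the positive semi-definiteness of the gradient outer product that the Figure 5 caption relies on.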

Theorems & Definitions (16)

  • Claim 1
  • Claim 2
  • Claim 3
  • Claim 4
  • Claim 5
  • Claim 6
  • Claim 7
  • proof
  • Claim 8
  • proof
  • ...and 6 more