Spectral-factorized Positive-definite Curvature Learning for NN Training
Wu Lin, Felix Dangel, Runa Eschenhagen, Juhan Bae, Richard E. Turner, Roger B. Grosse
TL;DR
This work tackles the challenge of flexible, stable SPD curvature learning for neural network training, where existing non-diagonal methods suffer from costly matrix-root computations. It introduces a spectral-factorized parameterization of the preconditioner $\mathbf{S}=\mathbf{B}\mathrm{Diag}(\mathbf{d})\mathbf{B}^T$ and develops a Riemannian gradient-descent framework in local coordinates, enabling efficient, root-free updates that support arbitrary matrix roots. The approach extends to Kronecker-structured preconditioners with determinant constraints, yielding scalable updates without explicit matrix decompositions. Empirically, the method performs competitively in SPD matrix optimization, gradient-free optimization, and low-precision neural network training, while maintaining reparameterization invariance and numerical stability. This spectral-factorized, geometry-aware framework extends non-diagonal curvature learning to a wider range of curvature information and practical training regimes.
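A minimal NumPy sketch of why the spectral factorization makes roots cheap, assuming $\mathbf{B}$ is orthogonal and $\mathbf{d}$ is positive (as in an eigendecomposition). Under that assumption $\mathbf{S}^{r}=\mathbf{B}\mathrm{Diag}(\mathbf{d}^{r})\mathbf{B}^T$ for any real $r$, with the power taken elementwise, so an inverse root needs no matrix decomposition. The sketch does not implement the paper's Riemannian update of the factors; the helper names `spectral_power` and `apply_inverse_root` are hypothetical.

```python
import numpy as np

def spectral_power(B, d, r):
    """Return S**r for S = B @ Diag(d) @ B.T, assuming B orthogonal and d > 0.

    With orthogonal B, S**r = B @ Diag(d**r) @ B.T for any real r, so roots
    and inverse roots need only an elementwise power of d.
    """
    # B * d**r scales column j of B by d[j]**r, i.e. B @ Diag(d**r).
    return (B * d ** r) @ B.T

def apply_inverse_root(B, d, g, p):
    """Compute S**(-1/p) @ g for a vector g of shape (n,) without forming S**(-1/p)."""
    return B @ (d ** (-1.0 / p) * (B.T @ g))

rng = np.random.default_rng(0)
n = 5
B, _ = np.linalg.qr(rng.standard_normal((n, n)))  # random orthogonal B
d = rng.uniform(0.5, 2.0, size=n)                 # positive spectrum
S = (B * d) @ B.T
g = rng.standard_normal(n)

# Reference inverse 4th root computed from a full eigendecomposition of S.
evals, evecs = np.linalg.eigh(S)
ref = (evecs * evals ** -0.25) @ evecs.T

print(np.allclose(spectral_power(B, d, -0.25), ref))        # True
print(np.allclose(apply_inverse_root(B, d, g, 4), ref @ g))  # True
```

In this sketch the factors are simply given; keeping them valid (orthogonal $\mathbf{B}$, positive $\mathbf{d}$) as they are learned is the job of the update rule, which the sketch does not model.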
Abstract
Many training methods, such as Adam(W) and Shampoo, learn a positive-definite curvature matrix and apply an inverse root of it before preconditioning. Recently, non-diagonal training methods, such as Shampoo, have gained significant attention; however, they remain computationally inefficient and are limited to specific types of curvature information because they compute matrix roots through costly matrix decompositions. To address this, we propose a Riemannian optimization approach that dynamically adapts spectral-factorized positive-definite curvature estimates, enabling the efficient application of arbitrary matrix roots and generic curvature learning. We demonstrate the efficacy and versatility of our approach in positive-definite matrix optimization and covariance adaptation for gradient-free optimization, as well as its efficiency in curvature learning for neural net training.
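For Kronecker-structured preconditioners such as Shampoo's, applying the inverse $p$-th root of $\mathbf{L}\otimes\mathbf{R}$ to a matrix gradient $\mathbf{G}$ reduces to $\mathbf{L}^{-1/p}\mathbf{G}\mathbf{R}^{-1/p}$. The sketch below assumes each factor is stored in spectral form with an orthogonal eigenvector matrix, normalizes each factor to unit determinant, and carries the overall scale in a single scalar. This is an illustrative way to impose a determinant constraint, not necessarily the paper's exact parameterization; all names (`kron_precondition`, `unit_det`, `random_spectral_factor`, `BL`, `dL`, `BR`, `dR`, `G`, `p`) are hypothetical.

```python
import numpy as np

def random_spectral_factor(rng, n):
    """Hypothetical helper: a random orthogonal B (via QR) and positive d."""
    B, _ = np.linalg.qr(rng.standard_normal((n, n)))
    return B, rng.uniform(0.5, 2.0, size=n)

def unit_det(d):
    """Split positive eigenvalues d into a unit-determinant part and a log-scale.

    Since det(B Diag(d) B^T) = prod(d), dividing d by its geometric mean
    makes the product 1 (up to rounding).
    """
    log_scale = np.mean(np.log(d))
    return d * np.exp(-log_scale), log_scale

def kron_precondition(G, BL, dL, BR, dR, p):
    """Apply (L kron R)^(-1/p) to G, i.e. return L^(-1/p) @ G @ R^(-1/p).

    L = BL Diag(dL) BL^T and R = BR Diag(dR) BR^T are never formed and no
    matrix decomposition is performed.
    """
    dL_unit, cL = unit_det(dL)
    dR_unit, cR = unit_det(dR)
    # L kron R = exp(cL + cR) * (L_unit kron R_unit): one scalar holds the scale.
    scale = np.exp(-(cL + cR) / p)
    # Rotate G into both eigenbases, rescale entrywise, rotate back.
    M = BL.T @ G @ BR
    M *= np.outer(dL_unit ** (-1.0 / p), dR_unit ** (-1.0 / p))
    return scale * (BL @ M @ BR.T)

rng = np.random.default_rng(1)
m, n, p = 3, 4, 4
BL, dL = random_spectral_factor(rng, m)
BR, dR = random_spectral_factor(rng, n)
G = rng.standard_normal((m, n))
P = kron_precondition(G, BL, dL, BR, dR, p)

# Reference: explicit (m*n) x (m*n) Kronecker product of dense inverse roots.
# With NumPy's row-major ravel, kron(A, C) @ X.ravel() == (A @ X @ C.T).ravel().
L_root = (BL * dL ** (-1.0 / p)) @ BL.T
R_root = (BR * dR ** (-1.0 / p)) @ BR.T
ref = (np.kron(L_root, R_root) @ G.ravel()).reshape(m, n)
print(np.allclose(P, ref))  # True

# Rescaling L by c and R by 1/c leaves the preconditioned gradient unchanged.
P_rescaled = kron_precondition(G, BL, 2.0 * dL, BR, dR / 2.0, p)
print(np.allclose(P, P_rescaled))  # True
```

Because $(c\mathbf{L})\otimes(\mathbf{R}/c)=\mathbf{L}\otimes\mathbf{R}$ for any $c>0$, the individual Kronecker factors are only determined up to such a rescaling; fixing each factor's determinant removes this ambiguity without changing the preconditioner, which the last check in the sketch illustrates.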
