Is All Learning (Natural) Gradient Descent?

Lucas Shoji, Kenta Suzuki, Leo Kozachkov

TL;DR

The paper shows that a broad class of effective learning rules, namely those that improve a scalar performance measure, can be recast as natural gradient descent with a symmetric positive definite metric $M(\theta,t)$. It derives a canonical form for this metric, $M = \frac{1}{y^T g} y y^T + \sum_{i=1}^{D-1} u_i u_i^T$, where $g$ is the update direction and $y$ the negative gradient. Within a one-parameter family, it identifies the choice $M_{\text{opt}}$ that minimizes the condition number; the eigenstructure of $M_{\text{opt}}$ is governed by the angle $\psi$ between the update direction and the negative gradient. The theory covers continuous-time, discrete-time, and stochastic learning rules as well as time-varying losses, and is demonstrated on a stable linear time-invariant system and on biologically plausible feedback-alignment learning. The work offers a unifying geometric lens on learning rules: it suggests that gradient-based optimization under an appropriate metric underpins diverse learning processes, and it supplies practical metrics for improving optimization efficiency and stability.
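A minimal NumPy sketch of the canonical form follows. It rests on assumptions this summary does not spell out: $y$ is taken to be the negative gradient, and the $u_i$ are chosen as $\sqrt{c}$ times an orthonormal basis of the subspace orthogonal to $g$, which is only one admissible choice. With these choices $M g = y$, so the natural-gradient update $-M^{-1}\nabla L$ recovers $g$ exactly, and the requirement $y^T g > 0$ (an acute angle $\psi$) is what keeps $M$ positive definite. The helper names and the example vectors are hypothetical.

```python
import numpy as np


def orthogonal_complement_basis(g: np.ndarray) -> np.ndarray:
    """Columns form an orthonormal basis of the subspace orthogonal to g, shape (D, D-1)."""
    # The first right-singular vector of the 1 x D matrix g^T is along g;
    # the remaining D-1 right-singular vectors span its orthogonal complement.
    _, _, vh = np.linalg.svd(g[None, :])
    return vh[1:].T


def canonical_metric(y: np.ndarray, g: np.ndarray, c: float = 1.0) -> np.ndarray:
    """M = y y^T / (y^T g) + sum_i u_i u_i^T, with u_i = sqrt(c) * (orthonormal basis of g-perp)."""
    yg = float(y @ g)
    if yg <= 0.0:
        raise ValueError("y^T g must be positive (update must make an acute angle with y)")
    U = np.sqrt(c) * orthogonal_complement_basis(g)
    return np.outer(y, y) / yg + U @ U.T


# Hypothetical example: a gradient and an update direction that differs from it.
grad = np.array([-1.0, 2.0, -0.5, 0.0, -1.5])  # gradient of L at the current theta
y = -grad                                       # negative gradient
g = np.array([1.2, -1.0, 0.0, 0.5, 1.0])        # update direction of some learning rule

M = canonical_metric(y, g)
psi = np.degrees(np.arccos(y @ g / (np.linalg.norm(y) * np.linalg.norm(g))))

print("symmetric:", np.allclose(M, M.T))
print("min eigenvalue:", np.linalg.eigvalsh(M).min())
print("M g == y:", np.allclose(M @ g, y))
# The natural-gradient step -M^{-1} grad reproduces the original update direction g.
print("-M^{-1} grad == g:", np.allclose(np.linalg.solve(M, -grad), g))
print("psi (degrees):", psi)
```

Any other set of $D-1$ linearly independent $u_i$ orthogonal to $g$ would also give $M g = y$; the choice only changes the spectrum of $M$, which is where the optimal metrics come in.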

Abstract

This paper shows that a wide class of effective learning rules -- those that improve a scalar performance measure over a given time window -- can be rewritten as natural gradient descent with respect to a suitably defined loss function and metric. Specifically, we show that parameter updates within this class of learning rules can be expressed as the product of a symmetric positive definite matrix (i.e., a metric) and the negative gradient of a loss function. We also demonstrate that these metrics have a canonical form and identify several optimal ones, including the metric that achieves the minimum possible condition number. The proofs of the main results are straightforward, relying only on elementary linear algebra and calculus, and are applicable to continuous-time, discrete-time, stochastic, and higher-order learning rules, as well as loss functions that explicitly depend on time.
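The minimum-condition-number metric mentioned in the abstract can be explored numerically. The sketch below assumes a specific one-parameter family, $M(c) = \frac{1}{y^T g} y y^T + c\,(I - \hat g \hat g^T)$ with $\hat g = g/\|g\|$ and $c > 0$, obtained from the canonical form by taking the $u_i$ proportional to an orthonormal basis of the complement of $g$. This family and the closed forms below are our own derivation for illustration, not necessarily the paper's family or its equations. Every member satisfies $M g = y$. Minimizing the condition number over $c$ gives $c = r/\cos\psi$ with $r = \|y\|/\|g\|$, eigenvalues $r(1 \pm \sin\psi)/\cos\psi$ and $r/\cos\psi$ (multiplicity $D-2$), and $\kappa = (1+\sin\psi)/(1-\sin\psi)$. This value coincides with the classical Kantorovich-type bound $\cos\psi \ge 2\sqrt{\kappa}/(1+\kappa)$ for any symmetric positive definite matrix mapping $g$ to $y$. The vectors are hypothetical.

```python
import numpy as np


def metric_family(y, g, c):
    """Assumed family M(c) = y y^T / (y^T g) + c (I - g_hat g_hat^T), with c > 0."""
    g_hat = g / np.linalg.norm(g)
    P_perp = np.eye(len(g)) - np.outer(g_hat, g_hat)  # projector onto the complement of g
    return np.outer(y, y) / (y @ g) + c * P_perp


def condition_number(M):
    eig = np.linalg.eigvalsh(M)  # eigenvalues of a symmetric matrix, ascending
    return eig[-1] / eig[0]


y = np.array([1.0, -2.0, 0.5, 0.0, 1.5])  # negative gradient (hypothetical values)
g = np.array([1.2, -1.0, 0.0, 0.5, 1.0])  # update direction (hypothetical values)
D = len(y)

r = np.linalg.norm(y) / np.linalg.norm(g)
cos_psi = (y @ g) / (np.linalg.norm(y) * np.linalg.norm(g))
sin_psi = np.sqrt(1.0 - cos_psi**2)

# Closed-form minimizer of the condition number within this family (our derivation).
c_opt = r / cos_psi
M_opt = metric_family(y, g, c_opt)

kappa_formula = (1.0 + sin_psi) / (1.0 - sin_psi)
predicted = np.sort([r * (1.0 - sin_psi) / cos_psi]
                    + [c_opt] * (D - 2)
                    + [r * (1.0 + sin_psi) / cos_psi])

print("kappa(M_opt):", condition_number(M_opt), "closed form:", kappa_formula)
print("eigenvalues:", np.linalg.eigvalsh(M_opt), "predicted:", predicted)

# Brute-force scan over c: no member of the family does better than c_opt.
cs = np.linspace(0.05, 10.0, 2000)
kappas = [condition_number(metric_family(y, g, c)) for c in cs]
print("best scanned kappa:", min(kappas))
```

At $\psi = 0$ the closed form collapses to $M_{\text{opt}} = rI$, so ordinary (scaled) gradient descent is recovered as a special case. The condition number grows without bound as $\psi$ approaches $90^\circ$.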

Paper Structure

This paper contains 28 sections, 7 theorems, 81 equations, 3 figures.

Key Result

Theorem 1

Suppose that $\mathcal{L} \colon \mathbb{R}^D \rightarrow \mathbb{R}$ is a twice continuously differentiable function, and that $p \in \mathbb{R}^D$. Then there exists some $\lambda \in (0,1)$ such that

Figures (3)

  • Figure 1: A) Contour lines of a loss function (darker colors = lower loss). Parameters update in the direction of $g$. If this update decreases the loss, and if the step-size is small, $g$ is equivalent to steepest descent with a non-Euclidean metric, $M(\theta)$. In this case, the angle $\psi$ between $g$ and the negative gradient is acute. Ellipse: $\epsilon$-ball in this metric. B) Steepest descent with the Euclidean metric. Circle: $\epsilon$-ball in this metric.
  • Figure 2: Natural gradient descent minimizes a loss function (dashed contours) by evolving the parameters $\theta$ in the direction of steepest descent in a non-Euclidean space. This space, a $D$-dimensional manifold with metric $M(\theta, t)$, is visualized as a surface embedded in a higher-dimensional Euclidean space. We demonstrate that a wide class of learning rules that decrease the loss function (not necessarily monotonically) fits this framework. In this context, the dynamics of both $\theta$ and $M$ are determined by the learning rule and the loss function.
  • Figure 3: A) Eigenvalues of the optimal metric $M_{\text{opt}}$ as a function of the angle $\psi$ between vectors $y$ and $g$, with the norm ratio $\|y\|/\|g\|$ fixed at unity (see the eigenvalue expression in the main text). B) Spectrum of $M_{\text{opt}}$ over time for stable linear time-invariant dynamics. C) Lyapunov function (loss) corresponding to the dynamics in (B), demonstrating a monotonic decrease. D) Spectrum of $M_{\text{opt}}$ for a small multi-layer network trained with a biologically plausible learning rule (feedback alignment) to classify MNIST digits. E) Training loss of feedback alignment as a function of training steps: the instantaneous loss is not strictly monotonic, but the average loss decreases over time.
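The linear time-invariant example in panels B and C can be mimicked with a short sketch. Everything specific here is an assumption rather than the paper's setup: the matrix $A$, the choice $Q = I$, the quadratic Lyapunov loss $L(\theta) = \theta^T P \theta$ with $A^T P + P A = -Q$, and the hypothetical helper `optimal_metric`. That helper uses the assumed family $M(c) = \frac{1}{y^T g} y y^T + c\,(I - \hat g \hat g^T)$ at its minimum-condition-number member $c = r/\cos\psi$, which may not match the paper's $M_{\text{opt}}$. For this loss, $y = -2P\theta$ and $g = A\theta$ give $y^T g = \theta^T Q \theta > 0$ for every nonzero $\theta$. The dynamics therefore take the natural-gradient form $g = M^{-1} y$ along the whole trajectory while the loss decreases monotonically.

```python
import numpy as np
from scipy.linalg import expm, solve_continuous_lyapunov


def optimal_metric(y, g):
    """Hypothetical helper: min-condition-number member of M(c) = yy^T/(y^T g) + c(I - g_hat g_hat^T)."""
    r = np.linalg.norm(y) / np.linalg.norm(g)
    cos_psi = (y @ g) / (np.linalg.norm(y) * np.linalg.norm(g))
    g_hat = g / np.linalg.norm(g)
    c = r / cos_psi
    return np.outer(y, y) / (y @ g) + c * (np.eye(len(g)) - np.outer(g_hat, g_hat))


# Hypothetical stable (Hurwitz) linear system d(theta)/dt = A theta.
A = np.array([[-1.0, 3.0],
              [-0.5, -2.0]])
Q = np.eye(2)
# Lyapunov equation A^T P + P A = -Q. SciPy solves a X + X a^H = q, so pass a = A^T, q = -Q.
P = solve_continuous_lyapunov(A.T, -Q)

dt, steps = 0.05, 100
Phi = expm(A * dt)                # exact one-step flow map of the linear ODE
theta = np.array([1.0, 1.0])      # hypothetical initial condition

losses, spectra = [], []
for _ in range(steps):
    loss = theta @ P @ theta      # quadratic Lyapunov function used as the loss
    y = -2.0 * P @ theta          # negative gradient of the loss
    g = A @ theta                 # update direction given by the dynamics
    M = optimal_metric(y, g)
    assert np.allclose(M @ g, y)  # natural-gradient form: g = M^{-1} y
    losses.append(loss)
    spectra.append(np.linalg.eigvalsh(M))
    theta = Phi @ theta

losses = np.array(losses)
spectra = np.array(spectra)
print("loss decreases monotonically:", bool(np.all(np.diff(losses) < 0)))
print("smallest metric eigenvalue over the run:", spectra.min())
```

Sampling the exact flow with `expm` rather than a forward-Euler step keeps the sampled loss strictly decreasing, since $\frac{d}{dt} L = -\theta^T Q \theta < 0$ along the continuous trajectory.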

Theorems & Definitions (15)

  • proof
  • Theorem 1: Taylor's Theorem
  • Definition 1: Discrete Gradient
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Lemma 3
  • proof
  • Proposition 1
  • ...and 5 more