ONG: Orthogonal Natural Gradient Descent

Yajat Yadav, Patrick Mendoza, Jathin Korrapati

TL;DR

ONG tackles catastrophic forgetting in continual learning by fusing natural gradient updates with orthogonal gradient projections. It preconditions task gradients with an EKFAC approximation of the inverse Fisher information matrix and protects previously learned tasks by projecting updates onto the orthogonal complement of earlier tasks' gradient directions, offering descent guarantees under the Fisher metric. However, empirical results on Permuted and Rotated MNIST reveal that naively combining Fisher preconditioning with Euclidean projections can degrade performance, indicating a geometry clash between the Fisher metric and standard projections. The work outlines promising future directions, including geometry-consistent projections (e.g., parallel transport) and more rigorous theoretical grounding, to enable robust, scalable continual learning.
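The geometry clash can be seen in two dimensions. The sketch below uses a hypothetical 2-by-2 Fisher matrix `F` (toy values, not from the paper): a Euclidean projection, as in OGD, makes the new direction orthogonal to a stored direction `v` in the ordinary dot product but not in the Fisher inner product $\boldsymbol{u}^\top \boldsymbol{F} \boldsymbol{w}$. The last projection, taken in the $\boldsymbol{F}$-inner product, is shown only for contrast; it is not presented as the authors' method, which points instead to parallel transport.

```python
import numpy as np

# Toy 2-D Fisher matrix (hypothetical values, symmetric positive definite).
F = np.array([[4.0, 1.0],
              [1.0, 1.0]])

v = np.array([1.0, 0.0])      # direction stored from a previous task
g_nat = np.array([1.0, 1.0])  # natural gradient for the current task

# Euclidean projection onto the orthogonal complement of v (as in OGD).
g_euclid = g_nat - (v @ g_nat) / (v @ v) * v   # [0., 1.]
euclid_inner = float(v @ g_euclid)              # 0.0: Euclidean-orthogonal
fisher_inner = float(v @ F @ g_euclid)          # 1.0: not F-orthogonal

# For contrast only: the same projection taken in the F-inner product.
g_fisher = g_nat - (v @ F @ g_nat) / (v @ F @ v) * v   # [-0.25, 1.]
fisher_inner_f = float(v @ F @ g_fisher)               # 0.0

print(euclid_inner, fisher_inner, fisher_inner_f)  # 0.0 1.0 0.0
```

The projected step is "orthogonal" to the old task only in the metric that was not used to compute it, which is one concrete way a naive combination can interfere with earlier tasks.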

Abstract

Orthogonal Gradient Descent (OGD) has emerged as a powerful method for continual learning. However, its Euclidean projections do not exploit the underlying information-geometric structure of the problem, which can lead to suboptimal convergence. To address this, we incorporate the natural gradient into OGD and present ONG (Orthogonal Natural Gradient Descent). ONG preconditions each new task-specific gradient with an efficient EKFAC approximation of the inverse Fisher information matrix, yielding updates that follow the steepest descent direction under a Riemannian metric. To preserve performance on previously learned tasks, ONG projects these natural gradients onto the orthogonal complement of prior tasks' natural gradients. We provide an initial theoretical justification for this procedure, introduce the ONG algorithm, and present preliminary results on the Permuted and Rotated MNIST benchmarks. These results, however, indicate that a naive combination of natural gradients and orthogonal projections can degrade performance. This finding motivates future work on robustly reconciling the two geometric perspectives in a continual learning method, establishing a more rigorous theoretical foundation with formal convergence guarantees, and extending empirical validation to large-scale continual learning benchmarks.
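The update described above (precondition, then project, then step) can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: a dense, damped solve against a toy Fisher matrix stands in for EKFAC, and the helper names (`natural_gradient`, `project_out`, `add_to_basis`, `ong_step`) and the `damping` and `lr` values are hypothetical.

```python
import numpy as np


def natural_gradient(grad, fisher, damping=1e-3):
    # Solve (F + damping * I) x = grad. A dense, damped solve stands in for
    # the EKFAC inverse-Fisher approximation (an assumption of this sketch).
    n = fisher.shape[0]
    return np.linalg.solve(fisher + damping * np.eye(n), grad)


def project_out(vec, basis):
    # Subtract the component of vec along each orthonormal basis vector,
    # leaving only the part in the orthogonal complement of span(basis).
    for b in basis:
        vec = vec - np.dot(b, vec) * b
    return vec


def add_to_basis(basis, vec, tol=1e-10):
    # Gram-Schmidt: orthonormalize vec against the stored basis and keep it
    # only if a non-negligible component remains.
    residual = project_out(vec, basis)
    norm = np.linalg.norm(residual)
    if norm > tol:
        basis.append(residual / norm)
    return basis


def ong_step(w, grad, fisher, basis, lr=0.1, damping=1e-3):
    # 1) precondition with the damped inverse Fisher; 2) project out the
    # directions stored for previous tasks; 3) take a descent step.
    nat = natural_gradient(grad, fisher, damping)
    return w - lr * project_out(nat, basis)


rng = np.random.default_rng(0)
d = 5
A = rng.standard_normal((d, d))
fisher = A @ A.T  # toy symmetric positive semi-definite "Fisher"

basis = []
# Hypothetical: a natural gradient left behind by a previous task.
add_to_basis(basis, natural_gradient(rng.standard_normal(d), fisher))

w = np.zeros(d)
w_new = ong_step(w, rng.standard_normal(d), fisher, basis)
# The step has (numerically) no component along the stored direction.
print(abs(np.dot(basis[0], w_new - w)))
```

Storing an orthonormalized basis makes each projection a sequence of dot products. In a real network, the dense solve would be replaced by the Kronecker-factored EKFAC inverse, which is what keeps the preconditioning tractable.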

Paper Structure

This paper contains 26 sections, 1 theorem, 9 equations, 2 figures, 4 tables, 2 algorithms.

Key Result

Lemma 4.1

Let $\boldsymbol{F}^{-1} \boldsymbol{g}$ be the natural gradient of the loss function $\mathcal{L}(\boldsymbol{w})$ and let $S=\{\boldsymbol{v}_1, \ldots, \boldsymbol{v}_n\}$ be an orthogonal basis. Define $\tilde{\boldsymbol{g}} = \boldsymbol{F}^{-1} \boldsymbol{g} - \sum_{i=1}^{k} \mathrm{proj}_{\boldsymbol{v}_i}\left(\boldsymbol{F}^{-1} \boldsymbol{g}\right)$ …
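The lemma statement is cut off in this summary, so the sketch below assumes only that $\tilde{\boldsymbol{g}}$ is the natural gradient minus its Euclidean projections $\mathrm{proj}_{\boldsymbol{v}}(\boldsymbol{u}) = \frac{\boldsymbol{v}^\top \boldsymbol{u}}{\boldsymbol{v}^\top \boldsymbol{v}}\,\boldsymbol{v}$ onto each $\boldsymbol{v}_i$. It checks the property that follows directly from $S$ being orthogonal: $\tilde{\boldsymbol{g}}$ has zero Euclidean inner product with every $\boldsymbol{v}_i$. The toy `F_inv`, `g`, and `S` are hypothetical values, and the function names are illustrative.

```python
import numpy as np


def proj(v, u):
    # Euclidean projection of u onto the line spanned by v.
    return (np.dot(v, u) / np.dot(v, v)) * v


def projected_natural_gradient(F_inv, g, S):
    # g_tilde = F^{-1} g - sum_i proj_{v_i}(F^{-1} g) for an orthogonal set S.
    nat = F_inv @ g
    return nat - sum(proj(v, nat) for v in S)


# Toy values (hypothetical): a diagonal inverse Fisher and an orthogonal set S.
F_inv = np.diag([0.5, 1.0, 2.0])
g = np.array([2.0, 1.0, 3.0])
S = [np.array([1.0, 1.0, 0.0]), np.array([1.0, -1.0, 0.0])]

g_tilde = projected_natural_gradient(F_inv, g, S)
print(g_tilde)                                  # [0. 0. 6.]
print([float(np.dot(v, g_tilde)) for v in S])   # [0.0, 0.0]
```

Because the $\boldsymbol{v}_i$ are mutually orthogonal, subtracting the projections one at a time removes every component in $\mathrm{span}(S)$ without reintroducing components along earlier vectors.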

Figures (2)

  • Figure 1: Validation accuracy on tasks 1, 5, and 10 of the Permuted MNIST benchmark during training. The model is trained sequentially on tasks 1 through 15, so the x-axis is indexed by task number.
  • Figure 2: Validation accuracy on tasks 1, 5, and 10 of the Rotated MNIST benchmark during training. The model is trained sequentially on tasks 1 through 15, so the x-axis is indexed by task number.

Theorems & Definitions (2)

  • Lemma 4.1
  • Definition 1: Parallel transport