Orthogonal Gradient Descent for Continual Learning

Mehrdad Farajtabar, Navid Azizan, Alex Mott, Ang Li

TL;DR

The paper tackles catastrophic forgetting in continual learning by introducing Orthogonal Gradient Descent (OGD), which preserves previously learned knowledge by projecting gradients onto the orthogonal complement of a subspace spanned by previous model-output gradients. This approach avoids storing raw past data and leverages the network's high capacity to learn new tasks with minimal interference. Through experiments on Permuted Mnist, Rotated Mnist, and Split Mnist, OGD demonstrates competitive performance against established baselines like EWC and A-GEM, approaching a theoretical multi-task upper bound in some settings. The work highlights practical memory considerations and potential extensions to other optimizers and higher-order information, offering a robust parameter-space perspective on continual learning.

Abstract

Neural networks are achieving state-of-the-art and sometimes super-human performance on learning tasks across a variety of domains. Whenever these problems require learning in a continual or sequential manner, however, neural networks suffer from the problem of catastrophic forgetting; they forget how to solve previous tasks after being trained on a new task, despite having the essential capacity to solve both tasks if they were trained on both simultaneously. In this paper, we propose to address this issue from a parameter space perspective and study an approach to restrict the direction of the gradient updates to avoid forgetting previously-learned data. We present the Orthogonal Gradient Descent (OGD) method, which accomplishes this goal by projecting the gradients from new tasks onto a subspace in which the neural network output on previous tasks does not change and the projected gradient is still in a useful direction for learning the new task. Our approach utilizes the high capacity of a neural network more efficiently and does not require storing the previously learned data that might raise privacy concerns. Experiments on common benchmarks reveal the effectiveness of the proposed OGD method.
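The projection idea described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: it assumes a toy linear model $f(x; w) = w \cdot x$ (for which the gradient of the output with respect to $w$ is simply $x$), a squared-error loss, and random toy data. The helper names `orthonormalize`, `project_out`, `mse_grad`, and `train` are hypothetical.

```python
import numpy as np

def orthonormalize(vectors, tol=1e-10):
    """Gram-Schmidt: return an orthonormal basis for the span of `vectors`."""
    basis = []
    for v in vectors:
        u = np.asarray(v, dtype=float).copy()
        for b in basis:
            u -= np.dot(b, u) * b        # remove the component along b
        norm = np.linalg.norm(u)
        if norm > tol:                   # skip numerically dependent vectors
            basis.append(u / norm)
    return basis

def project_out(g, basis):
    """Remove from g every component that lies in span(basis)."""
    g_tilde = g.copy()
    for b in basis:
        g_tilde -= np.dot(b, g_tilde) * b
    return g_tilde

def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X w - y)^2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)

def train(w, X, y, basis=(), lr=0.05, steps=200):
    """Plain gradient descent; each gradient is projected before the update."""
    for _ in range(steps):
        w = w - lr * project_out(mse_grad(w, X, y), basis)
    return w

rng = np.random.default_rng(0)
dim = 20
X_a, y_a = rng.normal(size=(5, dim)), rng.normal(size=5)   # toy task A
X_b, y_b = rng.normal(size=(5, dim)), rng.normal(size=5)   # toy task B

# Task A: ordinary training, then store the output-gradient directions.
w_a = train(np.zeros(dim), X_a, y_a)
basis = orthonormalize(list(X_a))    # for a linear model, grad_w f(x; w) = x

# Task B: every update is orthogonal to the stored directions.
w_b = train(w_a, X_b, y_b, basis=basis)

drift = np.max(np.abs(X_a @ w_b - X_a @ w_a))             # change in task-A outputs
loss_b_before = 0.5 * np.mean((X_b @ w_a - y_b) ** 2)
loss_b_after = 0.5 * np.mean((X_b @ w_b - y_b) ** 2)
print(drift, loss_b_before, loss_b_after)
```

In this linear toy case the outputs on the task A inputs stay fixed up to floating-point error while the task B loss still decreases. For a nonlinear network the output gradients depend on $w$, so the invariance would hold only approximately near the task A solution.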

Paper Structure

This paper contains 13 sections, 1 theorem, 12 equations, 5 figures, 7 tables, 1 algorithm.

Key Result

Lemma 3.1

Let $g$ be the gradient of the loss function $L(w)$ and let $S=\{v_1, \ldots, v_n\}$ be an orthogonal basis. Let $\tilde{g} = g - \sum_{i=1}^{n} \mathrm{proj}_{v_i}(g)$. Then $-\tilde{g}$ is also a descent direction for $L(w)$.
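As a numerical illustration of the lemma, the following sketch builds a small orthogonal set, forms $\tilde{g}$, and checks that $-\tilde{g}$ has a non-positive inner product with $g$. The vectors are random toy data, and the helper `proj` is a hypothetical name for the usual projection $\mathrm{proj}_{v}(g) = \frac{\langle g, v \rangle}{\langle v, v \rangle} v$.

```python
import numpy as np

def proj(v, g):
    """Projection of g onto the direction of v."""
    return (np.dot(g, v) / np.dot(v, v)) * v

rng = np.random.default_rng(1)

# Orthogonal (deliberately not unit-length) vectors from a reduced QR factorization.
q, _ = np.linalg.qr(rng.normal(size=(6, 3)))
S = [2.0 * q[:, 0], 0.5 * q[:, 1], 3.0 * q[:, 2]]

g = rng.normal(size=6)                        # stand-in for a loss gradient
g_tilde = g - sum(proj(v, g) for v in S)      # remove components along S

inner = np.dot(-g_tilde, g)                   # <= 0 means -g_tilde does not ascend
residuals = [abs(np.dot(g_tilde, v)) for v in S]
print(inner, max(residuals))
```

The inner product is non-positive and $\tilde{g}$ is orthogonal to every $v_i$ up to rounding, which is what the lemma requires.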

Figures (5)

  • Figure 1: An illustration of how Orthogonal Gradient Descent corrects the directions of the gradients. $g$ is the original gradient computed for task B and $\tilde{g}$ is the projection of $g$ onto the space orthogonal to the gradient $\nabla f_j(x;w_A^*)$ computed at task A. Moving within this (blue) space allows the model parameters to get closer to the low error (green) region for both tasks.
  • Figure 2: Performance of different methods on permuted Mnist task. 3 different permutations ($p_1$, $p_2$, and $p_3$) are used and the model is trained to classify Mnist digits under permutation $p_1$ for 5 epochs, then under $p_2$ for 5 epochs and then under $p_3$ for 5 epochs. The vertical dashed lines represent the points in the training where the permutations switch. The top plot reports the accuracy of the model on batches of the Mnist test set under $p_1$; the middle plot, under $p_2$; and the bottom plot under $p_3$. The y-axis is truncated to show the details. Note that MTL represents a setting where the model is directly trained on all previous tasks. Because we keep constant batch size and number of epochs, the MTL method effectively sees one third of the task 3 data that other methods do. This is the reason that MTL learns slower on task 3 than other methods.
  • Figure 3: Rotated Mnist: Accuracies of multiple continual learning methods. Every classifier is trained for 5 epochs on standard Mnist and then trained for another 5 epochs on a variant of Mnist whose images are rotated by the specified angle. The accuracy is computed over the entire original (un-rotated) Mnist test set after the model has been trained on the rotated dataset. Each bar represents the mean accuracy over 10 independent runs and the error bars reflect their standard deviations. MTL represents the (non-continual) multi-task learning setting where the model is trained with the combined data from all previous tasks.
  • Figure 4: Split Mnist: Accuracies of multiple continual learning methods. The training regime is the same as that of the permuted Mnist training figure. The reported value is the accuracy on task 1 after the model has been trained on task 2. Different plots correspond to different configurations, i.e., different partitions of the Mnist labels into task 1 and task 2.
  • Figure 5: The performance of OGD versus others as a function of the number of training epochs for each task on permuted Mnist.

Theorems & Definitions (2)

  • Lemma 3.1
  • proof
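
The proof itself is not reproduced on this page. A short sketch of the standard argument, assuming the $v_i$ are nonzero and mutually orthogonal and writing $\mathrm{proj}_{v_i}(g) = \frac{\langle g, v_i \rangle}{\langle v_i, v_i \rangle} v_i$, runs as follows. By orthogonality of the $v_i$, the vector $\tilde{g}$ is orthogonal to each $v_j$:

$$\langle \tilde{g}, v_j \rangle = \langle g, v_j \rangle - \sum_{i=1}^{n} \frac{\langle g, v_i \rangle}{\langle v_i, v_i \rangle} \langle v_i, v_j \rangle = \langle g, v_j \rangle - \langle g, v_j \rangle = 0 .$$

Writing $g = \tilde{g} + \sum_{i=1}^{n} \mathrm{proj}_{v_i}(g)$ and using this orthogonality gives

$$\langle -\tilde{g}, g \rangle = -\langle \tilde{g}, \tilde{g} \rangle - \sum_{i=1}^{n} \langle \tilde{g}, \mathrm{proj}_{v_i}(g) \rangle = -\|\tilde{g}\|^2 \le 0 ,$$

with strict inequality whenever $\tilde{g} \neq 0$, so a small enough step along $-\tilde{g}$ does not increase $L(w)$ to first order.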