Linear Mode Connectivity in Multitask and Continual Learning

Seyed Iman Mirzadeh, Mehrdad Farajtabar, Dilan Gorur, Razvan Pascanu, Hassan Ghasemzadeh

TL;DR

The paper investigates why multitask learning (MTL) often outperforms continual learning (CL) by examining the loss landscape. It finds that, when MTL and CL start from the same initialization, their minima are connected by a linear, low-loss path, suggesting that staying in low-curvature directions helps mitigate forgetting. Building on this, the authors introduce MC-SGD, a connectivity-aware CL algorithm that regularizes learning along lines toward prior minima and uses a replay buffer to approximate past tasks. MC-SGD outperforms several baselines on standard vision benchmarks, and its minima remain nearly linearly connected to previous solutions, providing both insight into the loss landscape and practical gains for continual learning.
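
The central measurement behind "connected by a linear, low-loss path" is the loss evaluated along the straight segment $w(\alpha) = (1-\alpha)w_a + \alpha w_b$ between two solutions. The sketch below evaluates a loss at evenly spaced $\alpha$ and summarizes the curve with a barrier height. The paper plots these curves; the single-number barrier (largest excess over the chord joining the endpoint losses), the 11-point grid and the toy two-parameter loss are assumptions for illustration, and the toy does not model the shared-initialization condition.

```python
from typing import Callable, List, Sequence

Loss = Callable[[Sequence[float]], float]


def interpolate(w_a: Sequence[float], w_b: Sequence[float], alpha: float) -> List[float]:
    """Point on the segment w(alpha) = (1 - alpha) * w_a + alpha * w_b."""
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(w_a, w_b)]


def loss_along_line(loss: Loss, w_a: Sequence[float], w_b: Sequence[float],
                    num_points: int = 11) -> List[float]:
    """Evaluate the loss at num_points evenly spaced alphas in [0, 1]."""
    alphas = [i / (num_points - 1) for i in range(num_points)]
    return [loss(interpolate(w_a, w_b, a)) for a in alphas]


def barrier_height(losses: Sequence[float]) -> float:
    """Largest excess of the path loss over the chord joining the two endpoint losses.

    Values near zero mean no loss barrier, i.e. the endpoints look linearly connected.
    """
    n = len(losses)
    start, end = losses[0], losses[-1]
    return max(value - ((1 - i / (n - 1)) * start + (i / (n - 1)) * end)
               for i, value in enumerate(losses))


def toy_loss(w: Sequence[float]) -> float:
    # Hypothetical 2-D loss: zero wherever w[0] = +1 or w[0] = -1, flat along w[1].
    return (w[0] ** 2 - 1.0) ** 2


# Two minima in the same flat valley vs. two minima in different valleys.
same_basin = barrier_height(loss_along_line(toy_loss, [1.0, 0.0], [1.0, 3.0]))
diff_basin = barrier_height(loss_along_line(toy_loss, [-1.0, 0.0], [1.0, 0.0]))
print(round(same_basin, 6), round(diff_basin, 6))  # -> 0.0 1.0
```

In a real network, `w_a` and `w_b` would be the flattened parameters of two trained models and `loss` would be the validation loss of the task of interest, as in Figure 4.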

Abstract

Continual (sequential) training and multitask (simultaneous) training are often attempting to solve the same overall objective: to find a solution that performs well on all considered tasks. The main difference is in the training regimes, where continual learning can only have access to one task at a time, which for neural networks typically leads to catastrophic forgetting. That is, the solution found for a subsequent task does not perform well on the previous ones anymore. However, the relationship between the different minima that the two training regimes arrive at is not well understood. What sets them apart? Is there a local structure that could explain the difference in performance achieved by the two different schemes? Motivated by recent work showing that different minima of the same task are typically connected by very simple curves of low error, we investigate whether multitask and continual solutions are similarly connected. We empirically find that indeed such connectivity can be reliably achieved and, more interestingly, it can be done by a linear path, conditioned on having the same initialization for both. We thoroughly analyze this observation and discuss its significance for the continual learning process. Furthermore, we exploit this finding to propose an effective algorithm that constrains the sequentially learned minima to behave like the multitask solution. We show that our method outperforms several state-of-the-art continual learning algorithms on various vision benchmarks.
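
The algorithm mentioned above (MC-SGD in the TL;DR) can be sketched under our own reading, which is an assumption rather than the paper's stated equation: given the previous solution $\bar{w}_{t-1}$ and the plain continual solution $\hat{w}_t$ for task $t$, look for a $w$ whose straight-line paths to both stay low-loss, i.e. minimize the average over sampled $\alpha$ of $\mathcal{L}_{\text{prev}}(\bar{w}_{t-1} + \alpha(w - \bar{w}_{t-1})) + \mathcal{L}_t(\hat{w}_t + \alpha(w - \hat{w}_t))$, with $\mathcal{L}_{\text{prev}}$ estimated on a replay buffer. To stay dependency-free, the sketch swaps autodiff for finite differences and SGD for full-batch gradient descent; the function `connectivity_sketch`, the $\alpha$ grid, the learning rate and the two quadratic toy losses are all hypothetical.

```python
from typing import Callable, List, Sequence

Loss = Callable[[Sequence[float]], float]


def line_loss(loss: Loss, anchor: Sequence[float], w: Sequence[float],
              alphas: Sequence[float]) -> float:
    """Mean loss at the points anchor + alpha * (w - anchor) for the sampled alphas."""
    total = 0.0
    for a in alphas:
        total += loss([p + a * (q - p) for p, q in zip(anchor, w)])
    return total / len(alphas)


def numerical_grad(f: Callable[[List[float]], float], w: List[float],
                   h: float = 1e-5) -> List[float]:
    """Central finite differences; stands in for autodiff in this sketch."""
    grad = []
    for i in range(len(w)):
        up, down = list(w), list(w)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def connectivity_sketch(replay_loss: Loss, current_loss: Loss,
                        prev_min: Sequence[float], current_min: Sequence[float],
                        alphas: Sequence[float] = (0.25, 0.5, 0.75, 1.0),
                        lr: float = 0.5, steps: int = 50) -> List[float]:
    """Hypothetical reading of the MC-SGD objective, minimized by gradient descent.

    Keeps the segment prev_min -> w low-loss on previous tasks (replay_loss)
    and the segment current_min -> w low-loss on the current task.
    """
    def objective(v: List[float]) -> float:
        return (line_loss(replay_loss, prev_min, v, alphas)
                + line_loss(current_loss, current_min, v, alphas))

    w = list(current_min)  # start from the plain continual solution
    for _ in range(steps):
        g = numerical_grad(objective, w)
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


def task1_loss(w: Sequence[float]) -> float:
    return w[0] ** 2  # hypothetical previous-task loss, evaluated on replay data


def task2_loss(w: Sequence[float]) -> float:
    return w[1] ** 2  # hypothetical current-task loss


w_bar_1 = [0.0, 2.0]  # solution kept after task 1
w_hat_2 = [1.5, 0.0]  # plain fine-tuning on task 2; forgets task 1
w = connectivity_sketch(task1_loss, task2_loss, w_bar_1, w_hat_2)
print(task1_loss(w_hat_2), round(task1_loss(w), 6), round(task2_loss(w), 6))
# -> 2.25 0.0 0.0
```

In this toy, plain fine-tuning leaves a task-1 loss of 2.25, while the connectivity objective moves to a point that is low-loss on both tasks and linearly connected to both endpoints. Dropping the $\alpha = 0$ sample is deliberate: at $\alpha = 0$ each term is evaluated at its fixed anchor and contributes no gradient.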

Paper Structure

This paper contains 27 sections, 5 equations, 18 figures, 1 table.

Figures (18)

  • Figure 1: Left: Depiction of the training regime considered. First, $\hat{w}_{1}$ is learned on task 1. Afterwards, we reach either $\hat{w}_{2}$ by learning the second task or $w^*_{2}$ by training on both tasks simultaneously. Right: Depiction of linear connectivity between $w^*_{2}$ and $\hat{w}_{1}$ and between $w^*_{2}$ and $\hat{w}_{2}$.
  • Figure 2: Continual and multitask learning performance and the relation between minima. Top row: Rotation MNIST. Bottom row: Split CIFAR-100. Left column: Accuracy of all tasks during continual training. Middle: Euclidean distance. Right: CKA distance. Note that $\hat{w}_5$ is not a good solution for task 1, although it is closer to $\hat{w}_1$ than $w^*_5$ is in Euclidean distance. Similarly, $\hat{w}_5$ is closer (more similar) to $\hat{w}_1$ than $w^*_5$ is in terms of CKA distance. Therefore, neither Euclidean nor CKA distance reveals that MTL is better at avoiding catastrophic forgetting.
  • Figure 3: Cross-entropy validation loss surface on Rotation MNIST (a and b) and Split CIFAR-100 (c and d), as a function of weights in a two-dimensional subspace passing through $\hat{w}_{1}$, $\hat{w}_{2}$, and $w^*_{2}$.
  • Figure 4: Exploring the loss along the linear paths connecting the different solutions. The loss increases on the interpolation line between the first-task solution ($\hat{w}_{1}$) and subsequent continual solutions, while it remains low on the interpolation line between $\hat{w}_{1}$ and subsequent multitask minima (a and c). The same observation also holds for the second-task solution ($\hat{w}_{2}$) (b and d).
  • Figure 5: Comparison of the eigenspectrum of the Hessian matrix for $\hat{w}_1$. (a): Top eigenvalues. (b and c): The overlap between Hessian eigenvectors and the directions from $\hat{w}_{1}$ to $\hat{w}_{2}$ and to $w^*_{2}$.
  • ...and 13 more figures
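
Figure 2 compares minima by Euclidean distance in weight space and by CKA distance between representations. A minimal sketch of both, assuming the linear CKA of Kornblith et al. (2019), $\mathrm{CKA}(X, Y) = \lVert Y^\top X \rVert_F^2 / (\lVert X^\top X \rVert_F \, \lVert Y^\top Y \rVert_F)$ on column-centered activations, and CKA distance $= 1 - \mathrm{CKA}$. The caption does not state which CKA variant or which layer's activations are used, so both are assumptions, and the activation matrices below are hypothetical.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]  # rows are examples, columns are features


def euclidean_distance(w_a: Sequence[float], w_b: Sequence[float]) -> float:
    """L2 distance between two flattened weight vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(w_a, w_b)))


def center_columns(x: Matrix) -> Matrix:
    """Subtract each feature's mean over the examples (rows)."""
    n = len(x)
    means = [sum(row[j] for row in x) / n for j in range(len(x[0]))]
    return [[v - m for v, m in zip(row, means)] for row in x]


def cross_gram_fro_sq(x: Matrix, y: Matrix) -> float:
    """Squared Frobenius norm of X^T Y."""
    total = 0.0
    for i in range(len(x[0])):
        for j in range(len(y[0])):
            s = sum(xr[i] * yr[j] for xr, yr in zip(x, y))
            total += s * s
    return total


def linear_cka(x: Matrix, y: Matrix) -> float:
    """Linear CKA on column-centered activations; 1.0 means identical up to rotation/scale."""
    x, y = center_columns(x), center_columns(y)
    return cross_gram_fro_sq(x, y) / math.sqrt(
        cross_gram_fro_sq(x, x) * cross_gram_fro_sq(y, y))


def cka_distance(x: Matrix, y: Matrix) -> float:
    return 1.0 - linear_cka(x, y)


# Hypothetical activations of 4 task-1 examples under different models.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
X_scaled = [[2.0 * v for v in row] for row in X]
Y = [[1.0], [1.0], [-1.0], [-1.0]]
print(cka_distance(X, X_scaled), cka_distance(X, Y))  # -> 0.0 1.0
print(euclidean_distance([0.0, 0.0], [3.0, 4.0]))     # -> 5.0
```

Linear CKA ignores isotropic rescaling of the representation, and Euclidean distance ignores how the loss varies between two points; neither directly measures task loss, which is consistent with the caption's observation that neither metric explains why the multitask minimum forgets less.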