
Order parameters and phase transitions of continual learning in deep neural networks

Haozhe Shan, Qianyi Li, Haim Sompolinsky

TL;DR

This work develops a Gibbs-based statistical framework to understand continual learning in deep neural networks, linking forgetting and interference to task relations and architecture through scalar order parameters. By deriving single-head and multi-head theories with generalized, time-dependent kernel functions, the authors predict when forgetting occurs and reveal phase transitions governed by task similarity and network load. The results show depth and width can mitigate forgetting, while multi-head readouts introduce rich phase behavior, including a catastrophic anterograde regime under certain conditions. These insights offer quantitative guidelines for designing neural systems and have implications for understanding how the brain achieves robust continual learning. The framework integrates analytic theory with numerical simulations across diverse task sequences, providing a principled lens for evaluating CL strategies and architectures.

Abstract

Continual learning (CL) enables animals to learn new tasks without erasing prior knowledge. CL in artificial neural networks (NNs) is challenging due to catastrophic forgetting, where new learning degrades performance on older tasks. While various techniques exist to mitigate forgetting, theoretical insights into when and why CL fails in NNs are lacking. Here, we present a statistical-mechanics theory of CL in deep, wide NNs, which characterizes the network's input-output mapping as it learns a sequence of tasks. It gives rise to order parameters (OPs) that capture how task relations and network architecture influence forgetting and anterograde interference, as verified by numerical evaluations. For networks with a shared readout for all tasks (single-head CL), the relevant-feature and rule similarity between tasks, respectively measured by two OPs, are sufficient to predict a wide range of CL behaviors. In addition, the theory predicts that increasing the network depth can effectively reduce interference between tasks, thereby lowering forgetting. For networks with task-specific readouts (multi-head CL), the theory identifies a phase transition where CL performance shifts dramatically as tasks become less similar, as measured by another task-similarity OP. While forgetting is relatively mild compared to single-head CL across all tasks, sufficiently low similarity leads to catastrophic anterograde interference, where the network retains old tasks perfectly but completely fails to generalize new learning. Our results delineate important factors affecting CL performance and suggest strategies for mitigating forgetting.

Paper Structure

This paper contains 74 sections, 171 equations, 15 figures, 1 table.

Figures (15)

  • Figure 1: Different types of task relations in CL and the weight-space schematics of the Gibbs formulation. a Hypothetical odor-mixture classification tasks demonstrating various possible relations between two supervised-learning tasks. Each task requires classifying two odor mixtures (each column is a mixture; black/white squares indicate present/absent odors). The subject must respond "L" or "R" depending on the presented mixture. The first task is A; the possible second tasks B1, B2, B3 and B4 have different relations to A. Tasks can have highly overlapping odors (A vs. B1, A vs. B2), partially overlapping odors (A vs. B4) or entirely different odors (A vs. B3). Among these odors, some are relevant for the task (e.g., odor 3 for task A) while others are not (e.g., odors 1 and 2 for task A). Tasks can also share the relevant odors (A vs. B1, A vs. B4) or have different relevant odors (A vs. B2). Furthermore, with the same relevant odors, the task rules can be the same (A vs. B4) or reversed (A vs. B1). It is crucial to quantify these different aspects of task relations and identify their impact on CL performance. b Weight-space schematics of the Gibbs formulation of CL. Each dataset defines a space of solutions where the training loss is zero. The network learns the first task by sampling from its space of solutions. For subsequent tasks, the solution the network adopts depends on the regularization strength ($\lambda$). At $\lambda=0$, learning each task samples from its corresponding space of solutions independently of previous learning. At the other extreme, $\lambda\rightarrow\infty$, learning chooses the solution closest to the weights sampled while learning the previous task. These schematics assume $\beta\rightarrow\infty$.
  • Figure 2: Schematics of the OPs and the target-distractor task. a-c Schematics of the OPs. The input features of each task span a $P$-dimensional subspace of the $N$-dimensional feature space. $\text{span}(\bm{X}_{1}^{L})$ (blue) denotes the space spanned by the task 1 input features at the $L$-th layer, and $\text{span}(\bm{X}_{2}^{L})$ (red) denotes the space spanned by the task 2 input features at the $L$-th layer. $\bm{V}_{1}$ and $\bm{V}_{2}$ are the rule vectors of tasks 1 and 2 and, by definition, lie in $\text{span}(\bm{X}_{1}^{L})$ and $\text{span}(\bm{X}_{2}^{L})$, respectively. $\gamma_{\text{RF}}$ measures how much the rule vectors project onto the shared feature dimensions. In a and c, both rule vectors lie entirely in the shared subspace and $\gamma_{\text{RF}}$ is high; in b, the rule vectors point away from the shared subspace, so $\gamma_{\text{RF}}$ is small. $\gamma_{\text{rule}}$ measures the similarity between the projections of the rule vectors onto the shared feature dimensions; it is high in a and low in b and c. $\gamma_{\text{feature}}$ measures the degree of overlap between the two tasks' feature subspaces; it is low in a and b but high in c. d Schematics of the target-distractor task. Each task consists of a set of $P$ images (rectangles) from CIFAR-100, each assigned a label of $\pm1$ or 0 (squares). $\rho_{\text{shared}}$ controls the fraction of images shared between the two tasks. $\rho_{\text{target}}$ controls the fraction of images with $\pm 1$ labels that are shared between the tasks. For the shared images, some labels are flipped between the tasks, and $\rho_{\text{flip}}$ controls the fraction of images with flipped labels. Varying these parameters allows us to explore the full range of the OPs. e-g Controlling the three OPs with the target-distractor task. $\gamma_{\text{feature}}$ depends only on $\rho_{\text{shared}}$ (e): as $\rho_{\text{shared}}$ increases, the overlap between the feature subspaces increases, resulting in a higher $\gamma_{\text{feature}}$. $\gamma_{\text{RF}}$ depends mainly on $\rho_{\text{target}}$ (f): as $\rho_{\text{target}}$ increases, the rule vectors project more onto the shared feature dimensions, so $\gamma_{\text{RF}}$ increases. $\gamma_{\text{rule}}$ is tuned by both $\rho_{\text{target}}$ and $\rho_{\text{flip}}$ (g). $\rho_{\text{target}}$ sets an upper bound on $\gamma_{\text{rule}}$, and for a fixed $\rho_{\text{target}}$, $\gamma_{\text{rule}}$ decreases with $\rho_{\text{flip}}$. At $\rho_{\text{flip}}=0.5$, about half of the labels are flipped and $\gamma_{\text{rule}}$ goes to 0.
  • Figure 3: OPs predict short-term and long-term forgetting behaviors in target-distractor sequences. a Forgetting on the training data of the first task after learning two tasks ($F_{2,1}$) is accurately predicted by $2(\gamma_{\text{RF}}-\gamma_{\text{rule}})$ and does not depend on $\gamma_{\text{feature}}$ (represented by the color of the points). Each point represents a target-distractor task sequence with a different set of $(\rho_{\text{target}},\rho_{\text{shared}},\rho_{\text{flip}})$. Inset: $\Delta F_{2,1}$ measures the effect of the regularizer. When $F_{2,1}$ is small ($<0.05$), so that $\gamma_{\text{RF}}$ and $\gamma_{\text{rule}}$ are close, the effect of the regularizer decreases as $\gamma_{\text{RF}}$ (and $\gamma_{\text{rule}}$) increases, i.e., as the tasks become more similar. b Long-term forgetting in a task sequence can be approximated by an exponential relaxation process, where $F_{\text{max}}$ denotes its asymptote as $t\rightarrow\infty$ and $\tau_{F}$ its time constant (see the SI section on the exponential fit). Two examples are shown: one for dissimilar tasks (low $\gamma_{\text{RF}}$, low $\gamma_{\text{rule}}$) with relatively large $\tau_F$ and $F_{\text{max}}$, and one for similar tasks (high $\gamma_{\text{RF}}$, high $\gamma_{\text{rule}}$) with relatively small $\tau_{F}$ and $F_{\text{max}}$. c Normalized PVE (proportion of variance explained; o1982measures) of $F_{2,1}$ and $\tau_{F}$ by the three OPs (see the SI section on the target-distractor task). $F_{2,1}$ depends on $\gamma_{\text{rule}}$ and $\gamma_{\text{RF}}$ and is independent of $\gamma_{\text{feature}}$, consistent with a. $\tau_F$ depends mainly on $\gamma_{\text{RF}}$ and only weakly on $\gamma_{\text{rule}}$ and $\gamma_{\text{feature}}$. d For tasks where $F_{2,1}$ is small ($<0.05$), $\tau_{F}$ decreases as $\gamma_{\text{RF}}$ increases. Inset: a zoomed-in view of $\gamma_{\text{RF}}<0.05$ highlights the rapid decrease of $\tau_{F}$ at small $\gamma_{\text{RF}}$. Data are binned by $\gamma_{\text{RF}}$; error bars are standard deviations across data points within the same bin. e $F_{\text{max}}$ can be accurately predicted from $F_{2,1}$ and $\tau_{F}$. Each point represents a target-distractor task sequence with a different set of $(\rho_{\text{target}},\rho_{\text{shared}},\rho_{\text{flip}})$; color represents the density of points. All $F_{t,1}$ and the corresponding OPs are averaged across 40 random seeds used for generating data. See the SI section on the target-distractor task for detailed parameters.
  • Figure 4: Forgetting in benchmark task sequences and the effect of depth. a-c The effect of depth on short-term ($F_{2,1}$) and long-term ($\tau_{F}$, $F_{\text{max}}$) forgetting on benchmark sequences. For permutation sequences, results are averaged over source datasets (MNIST, EMNIST, Fashion-MNIST and CIFAR) because their behaviors are similar; error bars are standard errors across source datasets. For split sequences, split EMNIST and split CIFAR-100 are shown separately because their behaviors differ more. In all cases, $F_{2,1}$ decreases with depth (a). Long-term forgetting is analyzed only for sequences with small $F_{2,1}$ (5, 10 and 15% permutation sequences); for these, $\tau_{F}$ increases with depth (b). Because $F_{2,1}$ and $\tau_{F}$ change in opposite directions, $F_{\text{max}}$ does not vary strongly with depth, and there may be an optimal depth at which $F_{\text{max}}$ is lowest (c). d Task-relation OPs ($\gamma_{\text{rule}}$ and $\gamma_{\text{RF}}$) on the benchmark sequences ($\gamma_{\text{feature}}$ is not shown because it does not strongly affect either $F_{2,1}$ or $\tau_{F}$; see Fig. 3). Blue, red and green correspond to permutation sequences, split EMNIST and split CIFAR-100, respectively. Shades from light to dark correspond to increasing network depth. In permutation sequences, larger points correspond to larger permutation ratios. The benchmark sequences explore a more constrained region of the OP space than the target-distractor sequences. e $F_{2,1}$ is accurately predicted by $2(\gamma_{\text{RF}}-\gamma_{\text{rule}})$, as in Fig. 3a. Increasing depth or decreasing the permutation ratio yields a smaller $\gamma_{\text{RF}}-\gamma_{\text{rule}}$ (as also shown in d) and thus a smaller $F_{2,1}$. f For tasks with small $F_{2,1}$ (5, 10 and 15% permutation sequences), $\tau_{F}$ decreases with $\gamma_{\text{RF}}$, consistent with Fig. 3d. For a fixed permutation ratio, increasing depth yields a smaller $\gamma_{\text{RF}}$ (as also shown in d) and thus a larger $\tau_{F}$. All $F_{t,1}$ and the corresponding OPs are averaged across 50 random seeds used for generating data. See the SI section on benchmark sequences for detailed parameters.
  • Figure 5: Multi-head CL exhibits phase transitions in the target-distractor sequence. a Schematics of multi-head CL. Different tasks share the same hidden-layer weights but use task-specific readouts. The weight-perturbation penalty is applied only to the hidden-layer weights. b Forgetting of task 1 ($F_{2,1}$) and the normalized generalization error on task 2 ($G_{2,2}$) as a function of the network load ($\alpha$) for two different sets of $(\rho_{\text{target}},\rho_{\text{shared}},\rho_{\text{flip}})$ in the target-distractor task, in the fixed-representation regime (FR, $\alpha<1$). Black arrows indicate divergence to infinity as $\alpha$ approaches 1. Curves of different colors correspond to tasks with different parameters $(\rho_{\text{target}},\rho_{\text{shared}},\rho_{\text{flip}})$. Light: $(1, 0.88, 0)$, $\gamma_{\text{sim}}=0.84$; dark: $(1, 0.58, 0.005)$, $\gamma_{\text{sim}}=0.52$. The generalization errors are calculated on the training data with small perturbations to the input (see the SI section on the target-distractor task). c The norm of ${\bf a}_{2}$, $\lVert{\bf a}_{2}\rVert^{2}/N$, as a function of $\alpha$ in the fixed-representation regime (FR). Since the hidden-layer representations are fixed, learning the second task is equivalent to learning the linear readout weights ${\bf a}_{2}$ by linear regression; the divergence of $G_{2,2}$ thus results from the divergence of ${\bf a}_{2}$ as the load approaches the interpolation threshold of linear regression. d Same as b, but for $\alpha>1$. For each combination of $(\rho_{\text{target}},\rho_{\text{shared}},\rho_{\text{flip}})$, $F_{2,1}$ and $G_{2,2}$ change abruptly as $\alpha$ crosses a critical load ($\alpha_{c}$, vertical dashed line). In the overfitting regime (OF, $1<\alpha<\alpha_{c}$), $F_{2,1}$ is zero but $G_{2,2}$ diverges. In the generalization regime (G, $\alpha>\alpha_{c}$), both $F_{2,1}$ and $G_{2,2}$ are moderate and nonzero. e Same as c, but for $\alpha>1$. The divergence of $G_{2,2}$ results from the divergence of ${\bf a}_{2}$, which compensates for keeping $\lVert\mathcal{W}_{2}-\mathcal{W}_{1}\rVert$ minimal when learning task 2. f The transition boundary between the fixed-representation regime (FR) and the overfitting regime (OF) is always at $\alpha=1$, independent of the task. The boundary between the overfitting regime (OF) and the generalization regime (G), $\alpha_{c}$, can be predicted theoretically from the task-similarity metric $\gamma_{\text{sim}}\in[-1,1]$ under reasonable assumptions (see the SI section summarizing the multi-head theory), as shown by the black line. Each red point shows the transition boundary $\alpha_{c}$ estimated from the shape of $F_{2,1}$ (see the SI section on the target-distractor task) for a different combination of $(\rho_{\text{target}},\rho_{\text{shared}},\rho_{\text{flip}})$, and thus a different value of $\gamma_{\text{sim}}$. The red points lie on the black curve, demonstrating the accuracy of the theoretical prediction. The light and dark brown points correspond to the curves shown in b and c.
  • ...and 10 more figures
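Figure 1b describes the $\lambda\rightarrow\infty$, $\beta\rightarrow\infty$ limit as choosing, among the zero-loss solutions for the new task, the one closest to the previous weights. A minimal sketch of that limit, assuming a linear model $f(\bm{x})=\bm{x}^{\top}\bm{w}$ with fewer training examples than weights (so the zero-loss set is an affine subspace) rather than the deep networks studied in the paper: the closest solution is the previous weights plus the minimum-norm correction, computed here with NumPy's `lstsq`. The random data and the helper name `min_change_solution` are hypothetical.

```python
import numpy as np

def min_change_solution(X2, y2, w1):
    """Zero-loss weights for task 2 that are closest (in L2 norm) to w1.

    Toy linear model f(x) = x @ w, assuming P2 < N so task 2 can be fit exactly.
    Solves: minimize ||w - w1||^2 subject to X2 @ w = y2.
    """
    residual = y2 - X2 @ w1                                  # what w1 still gets wrong on task 2
    delta, *_ = np.linalg.lstsq(X2, residual, rcond=None)    # minimum-norm correction
    return w1 + delta

rng = np.random.default_rng(0)
N, P = 50, 20                                                # hypothetical sizes
X1, y1 = rng.standard_normal((P, N)), rng.choice([-1.0, 1.0], P)
X2, y2 = rng.standard_normal((P, N)), rng.choice([-1.0, 1.0], P)

w1, *_ = np.linalg.lstsq(X1, y1, rcond=None)                 # one zero-loss solution for task 1
w2 = min_change_solution(X2, y2, w1)                         # "lambda -> infinity" limit
w2_indep, *_ = np.linalg.lstsq(X2, y2, rcond=None)           # a task-2 solution that ignores w1

print("task-2 training loss:", np.mean((X2 @ w2 - y2) ** 2))
print("task-1 loss, min-change update:", np.mean((X1 @ w2 - y1) ** 2))
print("task-1 loss, ignoring w1:", np.mean((X1 @ w2_indep - y1) ** 2))
```

The `w2_indep` line is only a rough stand-in for the $\lambda=0$ case: it returns one fixed zero-loss solution that ignores `w1`, whereas the Gibbs formulation samples from the whole solution space.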
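Figures 3b, 3e and 4c relate $F_{\text{max}}$ to $F_{2,1}$ and $\tau_F$. As a hedged illustration, assume the relaxation takes the form $F_{t,1}=F_{\text{max}}\left(1-e^{-(t-1)/\tau_F}\right)$ with $F_{1,1}=0$; the exact parametrization used in the paper is given in its SI and may differ. Evaluating at $t=2$ gives $F_{\text{max}}=F_{2,1}/\left(1-e^{-1/\tau_F}\right)\approx F_{2,1}\,\tau_F$ for large $\tau_F$, which makes the depth trade-off in Fig. 4 concrete: a smaller $F_{2,1}$ combined with a larger $\tau_F$ can leave $F_{\text{max}}$ almost unchanged. The sketch also encodes the empirical short-term relation $F_{2,1}\approx 2(\gamma_{\text{RF}}-\gamma_{\text{rule}})$ from Fig. 3a. All function names and numbers are hypothetical.

```python
import numpy as np

def predicted_f21(gamma_rf, gamma_rule):
    """Short-term forgetting from the two OPs, using the relation reported in Fig. 3a."""
    return 2.0 * (gamma_rf - gamma_rule)

def forgetting_curve(f_max, tau, T):
    """ASSUMED relaxation F_{t,1} = F_max * (1 - exp(-(t-1)/tau)) for t = 1..T."""
    t = np.arange(1, T + 1)
    return f_max * (1.0 - np.exp(-(t - 1) / tau))

def predict_f_max(f21, tau):
    """Invert the assumed form at t = 2: F_max = F_{2,1} / (1 - exp(-1/tau))."""
    return f21 / (1.0 - np.exp(-1.0 / tau))

# Hypothetical shallow vs. deep settings (not values from the paper):
# the deeper network forgets less after one task but relaxes more slowly.
f21_shallow = predicted_f21(gamma_rf=0.30, gamma_rule=0.25)   # ~0.10
f21_deep = predicted_f21(gamma_rf=0.20, gamma_rule=0.175)     # ~0.05
print(f"{predict_f_max(f21_shallow, tau=2.0):.3f}")           # prints 0.254
print(f"{predict_f_max(f21_deep, tau=4.5):.3f}")              # prints 0.251
```

Halving $F_{2,1}$ while more than doubling $\tau_F$ leaves the predicted asymptote nearly the same, mirroring the weak depth dependence of $F_{\text{max}}$ described for Fig. 4c.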
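Figure 5c attributes the divergence in the fixed-representation regime to the readout norm blowing up at the interpolation threshold of linear regression. A toy sketch of that mechanism, assuming random Gaussian features in place of the frozen hidden-layer representations, random $\pm1$ labels, and a load defined as $\alpha=P/N$ ($P$ training examples, $N$ features): for $\alpha<1$ the minimum-norm readout interpolates the data, and in this random-feature setting $\lVert{\bf a}\rVert^{2}/N$ grows roughly like $\alpha/(1-\alpha)$. The helper `readout_norm` is hypothetical, not code from the paper.

```python
import numpy as np

def readout_norm(alpha, N=400, seed=0):
    """||a||^2 / N for the minimum-norm readout fitting P = alpha * N random +/-1 labels
    on fixed random features (a toy stand-in for a frozen hidden layer, alpha < 1)."""
    rng = np.random.default_rng(seed)
    P = int(round(alpha * N))
    Phi = rng.standard_normal((P, N)) / np.sqrt(N)   # fixed features with variance 1/N
    y = rng.choice([-1.0, 1.0], size=P)              # hypothetical task labels
    a, *_ = np.linalg.lstsq(Phi, y, rcond=None)      # min-norm interpolating readout
    return float(a @ a) / N

for alpha in (0.3, 0.6, 0.9, 0.97):
    print(f"alpha={alpha:.2f}  ||a||^2/N={readout_norm(alpha):.2f}  "
          f"alpha/(1-alpha)={alpha / (1 - alpha):.2f}")
```

Near $\alpha=1$ single draws fluctuate strongly because the smallest singular value of the feature matrix approaches zero; larger $N$ tightens the agreement with $\alpha/(1-\alpha)$.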