Order parameters and phase transitions of continual learning in deep neural networks
Haozhe Shan, Qianyi Li, Haim Sompolinsky
TL;DR
This work develops a Gibbs-based statistical-mechanics framework for continual learning (CL) in deep neural networks, linking forgetting and anterograde interference to task relations and architecture through scalar order parameters. By deriving single-head and multi-head theories with generalized, time-dependent kernel functions, the authors predict when forgetting occurs and reveal phase transitions governed by task similarity and network load. The results show that depth and width can mitigate forgetting, while multi-head readouts introduce rich phase behavior, including a regime of catastrophic anterograde interference when task similarity is sufficiently low. These insights offer quantitative guidelines for designing neural systems and have implications for understanding how the brain achieves robust continual learning. The framework combines analytic theory with numerical simulations across diverse task sequences, providing a principled lens for evaluating CL strategies and architectures.
Abstract
Continual learning (CL) enables animals to learn new tasks without erasing prior knowledge. CL in artificial neural networks (NNs) is challenging due to catastrophic forgetting, where new learning degrades performance on older tasks. While various techniques exist to mitigate forgetting, theoretical insights into when and why CL fails in NNs are lacking. Here, we present a statistical-mechanics theory of CL in deep, wide NNs, which characterizes the network's input-output mapping as it learns a sequence of tasks. It gives rise to order parameters (OPs) that capture how task relations and network architecture influence forgetting and anterograde interference, as verified by numerical evaluations. For networks with a shared readout for all tasks (single-head CL), the relevant-feature and rule similarity between tasks, respectively measured by two OPs, are sufficient to predict a wide range of CL behaviors. In addition, the theory predicts that increasing the network depth can effectively reduce interference between tasks, thereby lowering forgetting. For networks with task-specific readouts (multi-head CL), the theory identifies a phase transition where CL performance shifts dramatically as tasks become less similar, as measured by another task-similarity OP. While forgetting is relatively mild compared to single-head CL across all tasks, sufficiently low similarity leads to catastrophic anterograde interference, where the network retains old tasks perfectly but completely fails to generalize new learning. Our results delineate important factors affecting CL performance and suggest strategies for mitigating forgetting.
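As a purely illustrative aside, the notion of forgetting in a single-head setting can be made concrete with a toy model. The sketch below is not the paper's Gibbs formulation and does not compute its order parameters. It assumes a linear student with one shared weight vector, isotropic Gaussian inputs, and two linear teacher rules whose overlap `rho` serves as a hypothetical stand-in for rule similarity. The names `make_teachers`, `fit_from`, and `forgetting` are invented for this example. Each task is learned by taking the interpolating solution closest to the current weights, which is the limit of gradient descent on squared loss from that starting point.

```python
import numpy as np


def make_teachers(d, rho, rng):
    """Two unit-norm teacher vectors with overlap rho (toy proxy for rule similarity)."""
    t1 = rng.standard_normal(d)
    t1 /= np.linalg.norm(t1)
    v = rng.standard_normal(d)
    v -= (v @ t1) * t1              # remove the component along t1
    v /= np.linalg.norm(v)
    t2 = rho * t1 + np.sqrt(1.0 - rho**2) * v
    return t1, t2


def fit_from(w0, X, y):
    """Interpolating solution closest to w0.

    This is the point gradient descent on squared loss converges to
    when started at w0 (assumes more parameters than samples).
    """
    return w0 + np.linalg.pinv(X) @ (y - X @ w0)


def forgetting(rho, d=200, n=50, seed=0):
    """Increase in task-1 population error caused by subsequently learning task 2."""
    rng = np.random.default_rng(seed)
    t1, t2 = make_teachers(d, rho, rng)
    X1 = rng.standard_normal((n, d))    # task-1 training inputs
    X2 = rng.standard_normal((n, d))    # task-2 training inputs
    w1 = fit_from(np.zeros(d), X1, X1 @ t1)   # learn task 1 from scratch
    w2 = fit_from(w1, X2, X2 @ t2)            # then learn task 2, same readout
    # For isotropic Gaussian inputs, the population squared error on task 1
    # equals the squared distance between the student and teacher 1.
    err_before = np.sum((w1 - t1) ** 2)
    err_after = np.sum((w2 - t1) ** 2)
    return err_after - err_before


if __name__ == "__main__":
    for rho in (1.0, 0.0, -1.0):
        print(f"rho={rho:+.1f}  forgetting={forgetting(rho):.3f}")
```

In this toy setting, identical rules (`rho = 1`) cannot increase the task-1 error, because the second fit projects the weights onto a solution set that contains the first teacher. Opposed rules (`rho = -1`) pull the shared weights away from it. The deep, wide networks analyzed in the paper are not captured by this linear sketch.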
