Drift to Remember

Jin Du, Xinhe Zhang, Hao Shen, Xun Xian, Ganghua Wang, Jiawei Zhang, Yuhong Yang, Na Li, Jia Liu, Jie Ding

TL;DR

The paper tackles catastrophic forgetting in lifelong learning by leveraging representational drift, a phenomenon observed in biological neural systems. It introduces DriftNet, a drift-driven framework that continuously explores local minima using external noise, encodes them into task-specific groups in a knowledge base, and retrieves the relevant group through uncertainty-based selection. On simulated tasks, CIFAR-10/100, and NLP tasks with GPT-2, DriftNet outperforms the Stable baseline and approaches the Joint and Theoretical Limits baselines. It scales to LLMs with billions of parameters on a single Nvidia A100 GPU and updates them using only new data. The approach offers a general, scalable mechanism for continual learning, with potential insights into biological learning and applicability to multi-domain, real-time AI systems.

Abstract

Lifelong learning in artificial intelligence (AI) aims to mimic the biological brain's ability to continuously learn and retain knowledge, yet it faces challenges such as catastrophic forgetting. Recent neuroscience research suggests that neural activity in biological systems undergoes representational drift, where neural responses evolve over time, even with consistent inputs and tasks. We hypothesize that representational drift can alleviate catastrophic forgetting in AI during new task acquisition. To test this, we introduce DriftNet, a network designed to constantly explore various local minima in the loss landscape while dynamically retrieving relevant tasks. This approach ensures efficient integration of new information and preserves existing knowledge. Experimental studies in image classification and natural language processing demonstrate that DriftNet outperforms existing models in lifelong learning. Importantly, DriftNet is scalable in handling a sequence of tasks such as sentiment analysis and question answering using large language models (LLMs) with billions of parameters on a single Nvidia A100 GPU. DriftNet efficiently updates LLMs using only new data, avoiding the need for full dataset retraining. Tested on GPT-2 and RoBERTa, DriftNet is a robust, cost-effective solution for lifelong learning in LLMs. This study not only advances AI systems to emulate biological learning, but also provides insights into the adaptive mechanisms of biological neural systems, deepening our understanding of lifelong learning in nature.
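
To make the exploration idea concrete, here is a minimal sketch of noise-driven drift on a toy problem. It is not the authors' implementation. The two-weight regression task with duplicated features, the function name `noisy_gd`, and all hyperparameters are assumptions chosen for illustration. Because both features are identical, every weight vector with $\beta_1 + \beta_2 = 2$ is a minimum, so injected noise can move the weights along that line without raising the loss.

```python
import numpy as np

def noisy_gd(X, y, beta0, lr=0.05, sigma=0.1, steps=2000, every=100, seed=0):
    """Gradient descent on mean squared error with Gaussian noise added to each update.

    Hypothetical helper: sigma plays the role of the external noise scale.
    Returns the weights recorded every `every` steps.
    """
    rng = np.random.default_rng(seed)
    beta = np.asarray(beta0, dtype=float).copy()
    snapshots = []
    for t in range(1, steps + 1):
        grad = 2.0 * X.T @ (X @ beta - y) / len(y)
        noise = sigma * np.sqrt(lr) * rng.standard_normal(beta.shape)
        beta = beta - lr * grad + noise
        if t % every == 0:
            snapshots.append(beta.copy())
    return np.array(snapshots)

# Assumed toy task: two identical features, so the loss depends only on beta1 + beta2.
rng = np.random.default_rng(1)
x = rng.standard_normal(64)
X = np.column_stack([x, x])
y = 2.0 * x  # every beta with beta1 + beta2 = 2 fits the data exactly

drifting = noisy_gd(X, y, beta0=np.zeros(2), sigma=0.1)  # with external noise
stable = noisy_gd(X, y, beta0=np.zeros(2), sigma=0.0)    # plain gradient descent

def spread(snaps, direction):
    """Standard deviation of the snapshots projected onto a unit direction."""
    d = np.asarray(direction, dtype=float)
    d = d / np.linalg.norm(d)
    return float(np.std(snaps @ d))

tangent_spread = spread(drifting, [1.0, -1.0])  # movement along the line of minima
normal_spread = spread(drifting, [1.0, 1.0])    # movement away from the line
stable_spread = spread(stable, [1.0, -1.0])     # noiseless run stays at one minimum
```

In this sketch the noisy run wanders along the line of minima while staying close to it, whereas the noiseless run settles at a single point. Each recorded snapshot is a candidate minimum that a knowledge base could store.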

Paper Structure (13 sections, 18 equations, 13 figures, 5 algorithms)

Figures (13)

  • Figure 1: Drift and catastrophic forgetting in lifelong learning. a, Illustrations of catastrophic forgetting. In a dynamic environment with the continual arrival of new tasks, the brain (top) can learn new tasks while retaining previously acquired knowledge. In contrast, a neural network (bottom), when fine-tuned continuously on new tasks, learns the new tasks but gradually forgets the earlier learned tasks. b, Schematics of representational drift in biological systems [driscoll2022representational]. The activity patterns of the same neurons that represent the same task constantly change over time. c, Illustration of a drift-inspired neural network. The weights of the network drift adaptively over time to avoid overwriting old information while learning new tasks, thus preserving performance on previously learned tasks.
  • Figure 2: Drift-inspired mechanism to prevent catastrophic forgetting in lifelong learning. a, Schematic illustration of stable and drifting networks. A stable network continuously learns new tasks, but overwrites previously acquired knowledge. In contrast, a drifting network allows its representations to continuously drift, thereby avoiding overwriting when learning new tasks. b-e, Implementation of the neural representational drift-inspired lifelong learning algorithm. b, Schematic overview showing the three steps of DriftNet: exploration (left), encoding (middle), and retrieval (right). DriftNet features an evolving model for exploration and a knowledge base for encoding and retrieving grouped task-specific information. c, Exploration step. Enabled by external noise, instead of (i) remaining fixed at a single local minimum like a stable network, the drifting network (ii) explores alternate minima in the current loss landscape. d, Encoding step. In a stable network, newly learned local minima overwrite previous ones, leading to gradual forgetting of previous tasks. In contrast, DriftNet constantly organizes and groups Task 1-specific local minima in the knowledge base during the learning of Task 2, preventing them from being overwritten by the newly learned minima from Task 2. e, Retrieval step. A stable network (i) cannot identify Task 1 information due to the forgetting caused by overwriting. In contrast, DriftNet (ii) can still identify the grouped Task 1-specific information after learning a new task, enabled by its drifting characteristic.
  • Figure 3: Benchmarking DriftNet lifelong learning performance on simulated datasets. a-d, Statistical results of DriftNet lifelong learning performance on simulated linear regression datasets. $n = 50$ experimental replicates. a, Boxplots with density plots of average test loss, showing the average test loss of two tasks relative to noise scales. The gray dotted line represents the stable baseline. b, Boxplots with density plots of drift rate, showing the average test loss of two tasks relative to noise scales. The gray dotted line represents the stable baseline. The black line represents the locally weighted scatter plot smoothing (LOWESS) curve of the average test loss with a fraction of $0.3$. c, Statistical summary of test losses for two tasks relative to epoch for drift (left) and stable (right) networks, respectively. The value represents the mean $\pm$ SE, $n = 50$ experimental replicates. d, Statistical summary of training losses with different noise levels ($\sigma$) relative to the epoch of drift (top) and stable networks (bottom), respectively. The value represents the mean $\pm$ SE, $n = 50$ experimental replicates. e, Scatter plot showing the trajectory of two model weights $(\beta_1,\beta_2)$ over time. The top plot contains points over the epoch, $(\beta_1,\beta_2, \text{epoch})$, in 3D space; the bottom plot contains $(\beta_1,\beta_2)$ data for all epochs. The orange and blue lines indicate the theoretical minima manifolds of Tasks 1 and 2, respectively. $\sigma=3$. f, Scatter plot showing performance vectors of minima mapped onto the first two principal components (PCs). h, Boxplot showing the retrieval accuracy relative to noise scales ($\sigma$). The dashed green line indicates the stable baseline. g, Boxplot showing the uncertainty of task-specific groups of local minima, evaluated on a batch of input data from the relevant task (in-distribution) and the irrelevant task (out-of-distribution). $\sigma=0.001$, batch size $16$. Box, $75\%$ and $25\%$ quantiles. Line, median. Whisker, median $\pm 1.5\times$ interquartile range (IQR). $n=50$ experimental replicates. ****p-values of the Mann-Whitney $U$ test and Student's t-test are less than $10^{-4}$.
  • Figure 4: Image classification. a, Statistical summary of the average test accuracy of all tasks relative to the number of seen tasks, for CIFAR-10 (left) and CIFAR-100 (right). The value represents the mean $\pm$ SE, $n = 10$ experimental replicates. b-f, Statistical results of DriftNet performance on CIFAR-10. b, Statistical summary of the training loss with different noise scales ($\sigma$) relative to epoch. The value represents the mean $\pm$ SE, $n = 10$ experimental replicates. c, Scatter plot showing feature drifts of the first category, projected on the first two principal components (PCs) of the feature drifts (see Methods), at epochs $50$ and $100$. $\sigma=0.001$. d, Scatter plot showing the performance vectors of the minima, evaluated in the buffer (see Methods), mapped to the first three principal components. $\sigma=0.001$. e, Boxplot showing the retrieval accuracy relative to noise scales ($\sigma$) (see Methods). f, Boxplot showing the uncertainty of task-specific groups of local minima, evaluated on a batch of input data from the relevant task (in-distribution) and the irrelevant task (out-of-distribution); see Methods. $\sigma=0.001$, batch size $16$. Box, $75\%$ and $25\%$ quantiles. Line, median. Whisker, the most extreme data point within the median $\pm 1.5\times$ interquartile range (IQR). $n=10$ experimental replicates. ****p-values of the Mann-Whitney $U$ test and Student's t-test are less than $10^{-4}$.
  • Figure 5: Large language model GPT-2 applications. a, Illustration of the lifelong learning chatbot (DriftNet). After learning four language tasks, the chatbot retrieves the relevant set of minima to provide accurate responses. b, Barplots showing the test accuracy for each of the four tasks during their learning process. The bar represents the mean $\pm$ SE, $n=5$ experimental replicates. c, Barplots showing the difference in test accuracy between methods and Theoretical Limits for each of the four tasks during the learning process. The bar represents the mean $\pm$ SE, $n=5$ experimental replicates. Red boxes indicate Drift networks learning a new task with mild forgetting. Green boxes indicate Stable networks forgetting previous tasks. d, Statistical summary of the average test accuracy of all tasks relative to the number of seen tasks, for different methods. The value represents the mean $\pm$ SE, $n = 5$ experimental replicates. The gray dashed line represents the Joint baseline, where a pre-trained large language model is fine-tuned offline using all task data. e, Boxplots showing the uncertainty of task-specific groups of local minima, evaluated on a batch of input data from the relevant task (in-distribution) and the irrelevant task (out-of-distribution); see Methods. Batch size $16$. Box, $75\%$ and $25\%$ quantiles. Line, median. Whisker, the most extreme data point within the median $\pm 1.5\times$ interquartile range (IQR). $n=5$ experimental replicates.
  • ...and 8 more figures
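
The captions above describe retrieval as comparing the uncertainty of each task-specific group of minima on a batch of inputs, with in-distribution batches showing lower uncertainty than out-of-distribution ones. A minimal sketch of that selection rule follows. It assumes, purely for illustration, that uncertainty is measured as the variance of predictions across the members of a group; the paper's actual measure is defined in its Methods. The names `group_uncertainty` and `retrieve`, the two toy tasks, and the group construction are hypothetical.

```python
import numpy as np

def group_uncertainty(minima, X):
    """Mean variance of linear predictions across the minima of one group.

    Assumption: disagreement among group members stands in for uncertainty.
    minima has shape (members, features); X has shape (batch, features).
    """
    preds = minima @ X.T  # shape (members, batch)
    return float(preds.var(axis=0).mean())

def retrieve(knowledge_base, X):
    """Return the group with the lowest uncertainty on batch X, plus all scores."""
    scores = {name: group_uncertainty(m, X) for name, m in knowledge_base.items()}
    return min(scores, key=scores.get), scores

rng = np.random.default_rng(0)
t = rng.standard_normal(8)  # positions of 8 stored minima along each manifold

# Assumed toy groups: Task 1 minima satisfy beta1 + beta2 = 2 and differ along (1, -1);
# Task 2 minima satisfy beta1 - beta2 = 1 and differ along (1, 1).
knowledge_base = {
    "task1": np.array([1.0, 1.0]) + t[:, None] * np.array([1.0, -1.0]),
    "task2": np.array([0.5, -0.5]) + t[:, None] * np.array([1.0, 1.0]),
}

# Batches of 16 inputs: Task 1 inputs lie along (1, 1), Task 2 inputs along (1, -1).
x = rng.standard_normal(16)
batch_task1 = x[:, None] * np.array([1.0, 1.0])
batch_task2 = x[:, None] * np.array([1.0, -1.0])

choice1, scores1 = retrieve(knowledge_base, batch_task1)
choice2, scores2 = retrieve(knowledge_base, batch_task2)
```

In this toy setup the members of a group agree on inputs from their own task and disagree on inputs from the other task, so the lowest-uncertainty group matches the task that produced the batch.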