Efficient Continual Learning with Modular Networks and Task-Driven Priors

Tom Veniat, Ludovic Denoyer, Marc'Aurelio Ranzato

TL;DR

This work reframes continual learning by emphasizing not only memory of past tasks but also transfer and scalability across long task streams. It introduces the CTrL benchmark to probe transfer and efficiency, and proposes Modular Networks with a Task-Driven Prior (MNTDP), a modular architecture that reuses past modules and adds new ones for new tasks under a data-driven prior to keep the search space manageable. The approach demonstrates competitive performance on standard CL benchmarks and superior transfer and scalability on CTrL, with a clear trade-off between prior breadth and computational budget. This modular, prior-guided paradigm offers practical paths toward scalable, transferable lifelong learning in real-world long-tail task sequences.

Abstract

Existing literature in Continual Learning (CL) has focused on overcoming catastrophic forgetting, the inability of the learner to recall how to perform tasks observed in the past. There are, however, other desirable properties of a CL system, such as the ability to transfer knowledge from previous tasks and to scale memory and compute sub-linearly with the number of tasks. Since most current benchmarks focus only on forgetting using short streams of tasks, we first propose a new suite of benchmarks to probe CL algorithms across these new axes. We then introduce a new modular architecture, whose modules represent atomic skills that can be composed to perform a certain task. Learning a task reduces to figuring out which past modules to re-use, and which new modules to instantiate to solve the current task. Our learning algorithm leverages a task-driven prior over the exponential search space of all possible ways to combine modules, enabling efficient learning on long streams of tasks. Our experiments show that this modular architecture and learning algorithm perform competitively on widely used CL benchmarks while yielding superior performance on the more challenging benchmarks we introduce in this work.

Paper Structure

This paper contains 34 sections, 5 equations, 15 figures, 17 tables, 2 algorithms.

Figures (15)

  • Figure 1: Comparison of various CL methods on the CTrL benchmark using ResNet (left) and AlexNet (right) backbones. MNTDP-D is our method. See Tab. \ref{tab:res-all-metrics} of §\ref{sec:old_bench} for details.
  • Figure 2: Toy illustration of the approach when each predictor is composed of only three modules and only two tasks have already been observed. (A) The predictor of the first task uses modules (1,1,1) (listing modules by increasing depth in the network), while the predictor of the second task uses modules (1,2,2); the first-layer module is shared between the two predictors. (B) When a new task arrives, we first add one new randomly initialized module at each layer (the dashed modules). We then search for the most similar past task and retain only the corresponding architecture. In this case, the second task is most similar, so we remove (gray out) the modules used only by the predictor of the first task. (C) We train on the current task by learning both the best way to combine modules and their parameters, but over a restricted search space. In this case, we consider only four possible compositions, all derived by perturbing the predictor of the second task. In the stochastic version (MNTDP-S), a path (sequence of modules) is selected stochastically for every input. Notice that the same module may contribute to multiple paths (e.g., the top-most layer with id 3). In the deterministic version (MNTDP-D), we instead train all paths in parallel and then select the best. Note that only the parameters of the newly added (dashed) modules are subject to learning. (D) Assuming that the best architecture found at the previous step is (1,2,3), module 3 at the top layer is added to the current library of modules.
  • Figure 3: Results on standard continual learning streams. * denotes an AlexNet backbone. $\dagger$ denotes models cross-validated at the stream level, a setting that favors them over the other methods, which are cross-validated at the task level. Detailed results are presented in Appendix \ref{app:classic_streams}.
  • Figure 4: Results on the long evaluation stream. * denotes models using an AlexNet backbone. See Tab. \ref{tab:res-long-stream-full} for more baselines and error bars.
  • Figure 5: Evolution of $\langle\mathcal{A}\rangle$ and Mem. on $\mathcal{S}^{\text{long}}$.
  • ...and 10 more figures
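
The restricted search space described in the Figure 2 caption can be illustrated with a minimal Python sketch. It assumes one particular perturbation rule that the caption does not spell out: a candidate path keeps the first k modules of the retained predictor and uses the newly added modules at every deeper layer. Under that assumption a three-layer predictor yields exactly four candidates, including the path (1,2,3) mentioned in panel D. The function names and the ids given to the new modules at the first two layers are hypothetical; the caption only confirms id 3 at the top layer.

```python
from typing import List, Sequence, Tuple

Path = Tuple[int, ...]


def candidate_paths(retained: Path, new_ids: Path) -> List[Path]:
    """Enumerate candidate paths under an ASSUMED perturbation rule.

    Each candidate reuses the first k modules of `retained` (the predictor
    of the most similar past task) and uses the newly added module at every
    deeper layer, for k = depth, depth - 1, ..., 0.
    """
    if len(retained) != len(new_ids):
        raise ValueError("retained path and new module ids must have equal depth")
    depth = len(retained)
    return [retained[:k] + new_ids[k:] for k in range(depth, -1, -1)]


def unrestricted_space_size(modules_per_layer: Sequence[int]) -> int:
    """Number of paths if any module could be chosen freely at each layer."""
    size = 1
    for count in modules_per_layer:
        size *= count
    return size


# Toy setting from the caption: the retained predictor is (1, 2, 2).
# The new module ids per layer are illustrative assumptions.
retained = (1, 2, 2)
new_ids = (2, 3, 3)

paths = candidate_paths(retained, new_ids)
# Per-layer module counts after adding one new module to each layer,
# before pruning the modules used only by the first task.
unrestricted = unrestricted_space_size([2, 3, 3])
```

In this sketch the restricted space grows linearly with depth (depth + 1 candidates), whereas the unrestricted count is the product of the per-layer module counts, which is what makes an exhaustive search impractical on long task streams.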