Class incremental learning with probability dampening and cascaded gated classifier

Jary Pomponi, Alessio Devoto, Simone Scardapane

TL;DR

This work tackles class incremental learning by addressing forgetting through Margin Dampening (MD) and a novel Cascaded Gates (CG) classifier. MD imposes a margin-based constraint on past-task probabilities and adds a memory-based knowledge-distillation regulariser, while CG ensembles scaled, task-wise heads via gating to preserve past knowledge without directly disturbing previous outputs. Empirical results across CIFAR and TinyImageNet variants show MD-CG achieves a favorable stability-plasticity balance and outperforms a wide range of baselines, with modest memory overhead. The approach offers a practical, memory-efficient path for continual learning in realistic task streams, though CG growth is quadratic in the number of tasks, suggesting avenues for theoretical analysis and further scalability improvements.
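The cascaded gating idea in the TL;DR, and why its growth is quadratic, can be made concrete with a small sketch. This is not the authors' implementation. The class `CascadedScaledClassifier`, the scaling function `scale_fn` and the choice of one scalar gate per (stage, head) pair are our own hypothetical assumptions, and the paper's scaling functions may well depend on the input. Under these assumptions, task t brings a new head and t gates, so T tasks hold T(T+1)/2 gates.

```python
import math
from typing import Callable, List, Sequence

# A task head maps a feature vector to that task's class logits.
Head = Callable[[Sequence[float]], List[float]]


def scale_fn(raw: float) -> float:
    """Hypothetical scaling function with range (0, 2); equals 1.0 at raw == 0."""
    return 2.0 / (1.0 + math.exp(-raw))


class CascadedScaledClassifier:
    """Illustrative cascaded classifier (assumed design, not the paper's code).

    Every new task adds one head plus a new stage of scalar gates, one gate
    per head that exists at that point, so T tasks hold T * (T + 1) / 2 gates.
    """

    def __init__(self) -> None:
        self.heads: List[Head] = []
        # gates[t][i]: raw gate that stage t applies to the output of head i (i <= t)
        self.gates: List[List[float]] = []

    def add_task(self, head: Head) -> None:
        self.heads.append(head)
        # Gates start at 0.0, i.e. scale 1.0, so past outputs are unchanged at insertion
        self.gates.append([0.0] * len(self.heads))

    def num_gates(self) -> int:
        return sum(len(stage) for stage in self.gates)

    def forward(self, features: Sequence[float]) -> List[float]:
        outputs: List[float] = []
        for i, head in enumerate(self.heads):
            # Cascade: head i is rescaled by every stage created at or after task i
            scale = 1.0
            for t in range(i, len(self.gates)):
                scale *= scale_fn(self.gates[t][i])
            outputs.extend(scale * z for z in head(features))
        return outputs


if __name__ == "__main__":
    clf = CascadedScaledClassifier()
    clf.add_task(lambda x: [x[0], -x[0]])       # task 1 head: two classes
    clf.add_task(lambda x: [x[1], 2.0 * x[1]])  # task 2 head: two classes
    print(clf.forward([1.0, 0.5]))  # [1.0, -1.0, 0.5, 1.0] while all gates are at init
    print(clf.num_gates())          # 3 = 2 * 3 / 2
```

Initialising each raw gate at 0.0 (scale 1.0) is a choice made here so that adding a task leaves earlier outputs untouched until training adjusts the gates; later stages can then rescale old heads without editing their weights.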

Abstract

Humans can acquire new knowledge and transfer learned knowledge to different domains while incurring little forgetting. The same ability, called Continual Learning, is challenging to achieve with neural networks, because learning new tasks causes the forgetting of previously learned ones. This forgetting can be mitigated by replaying stored samples from past tasks, but long sequences of tasks may require a large memory; moreover, replay can lead to overfitting on the saved samples. In this paper, we propose a novel regularisation approach and a novel incremental classifier, called Margin Dampening and Cascaded Scaling Classifier, respectively. The former combines a soft constraint and a knowledge distillation approach to preserve previously learned knowledge while allowing the model to learn new patterns effectively. The latter is a gated incremental classifier that helps the model modify past predictions without directly interfering with them, by modifying the output of the model with auxiliary scaling functions. We empirically show that our approach performs well on multiple benchmarks against well-established baselines, and we study each component of our proposal and how combinations of these components affect the final results.
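The two ingredients of Margin Dampening described above can be sketched on single samples in plain Python. This is an illustration under stated assumptions, not the paper's loss: we assume the soft constraint is a hinge that penalises each past-class probability only while it lies within the margin m of the ground-truth probability (following the description of Figure 2), and we substitute a standard KL-divergence distillation term for the memory-based regulariser. The helper names (`margin_dampening`, `distillation_kl`, `md_objective`) and the weights `alpha` and `beta` are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax."""
    mx = max(logits)
    lse = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return [z - lse for z in logits]


def margin_dampening(logits: Sequence[float], target: int,
                     past_classes: Sequence[int], margin: float) -> float:
    """Assumed hinge form: a past-class probability is pushed down only while it
    is within `margin` of the ground-truth probability; beyond that, no penalty."""
    p = [math.exp(lp) for lp in log_softmax(logits)]
    return sum(max(0.0, p[j] - p[target] + margin) for j in past_classes)


def distillation_kl(old_logits: Sequence[float], new_logits: Sequence[float],
                    temperature: float = 1.0) -> float:
    """Standard knowledge-distillation term KL(p_old || p_new) for a memory sample."""
    lp_old = log_softmax([z / temperature for z in old_logits])
    lp_new = log_softmax([z / temperature for z in new_logits])
    return sum(math.exp(a) * (a - b) for a, b in zip(lp_old, lp_new))


def md_objective(cur_logits: Sequence[float], target: int,
                 past_classes: Sequence[int], margin: float,
                 mem_old_logits: Sequence[float], mem_new_logits: Sequence[float],
                 alpha: float = 1.0, beta: float = 1.0) -> float:
    """Cross-entropy on a current-task sample + weighted dampening + weighted KD."""
    ce = -log_softmax(cur_logits)[target]
    md = margin_dampening(cur_logits, target, past_classes, margin)
    kd = distillation_kl(mem_old_logits, mem_new_logits)
    return ce + alpha * md + beta * kd
```

Because the hinge vanishes once every past-class probability is at least m below the ground-truth one, past outputs are dampened rather than driven to zero, which is the soft behaviour the abstract contrasts with a hard constraint.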

Paper Structure (32 sections, 12 equations, 10 figures, 5 tables)

Figures (10)

  • Figure 1: We train a ResNet-20 model with a simple rehearsal approach and a memory size of 500, on CIFAR10 divided into 5 tasks. A) Magnitude of ground truth logits on the replay samples; B) Forgetting on the replay samples; C) Forgetting on test scores for each past task.
  • Figure 2: On the left, we show the probability dampening scheme (MD, Section \ref{sec:mbm}), in which the probabilities of past tasks, $p^{1:2}(x)$, are decreased up to a margin $m$ with respect to the ground-truth probability, $p^{3:3}_y(x)$. On the right, we visualise the cascaded procedure (CG, Section \ref{sec:scaler}), which combines scaled task-wise outputs to build the final prediction $f^3_\text{CG}(x)$. Both components are shown for a three-task scenario in which the last task is the one being trained. Best viewed in colour.
  • Figure 3: How the accuracy (left) and BWT (right) scores are affected when varying the memory size and the past-margin regularisation. The results are obtained with ResNet-20 trained on C10-5.
  • Figure 4: Additional floats required by each approach, plotted against the achieved accuracy. As additional floats, we count all pixels stored in memory as well as any additional parameters a method requires.
  • Figure 5: Cosine distance between the gradients produced by the loss on rehearsal samples and those produced by the Knowledge Distillation regularisation. The last layer (last row of the heatmap) is the one in which the two gradients consistently point in different directions.
  • ...and 5 more figures
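The per-layer analysis in Figure 5 reduces to a cosine distance between two gradient vectors for each layer. The sketch below assumes the gradients of the rehearsal loss and of the distillation loss have already been computed separately (for example with an autodiff framework) and flattened into one vector per layer; the function names and the layer names in the example are hypothetical. Cosine distance ranges from 0 (same direction) through 1 (orthogonal) to 2 (opposite directions).

```python
import math
from typing import Dict, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cos(a, b): 0 for aligned, 1 for orthogonal, 2 for opposite vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 1.0  # convention chosen here: a zero gradient has no direction
    return 1.0 - dot / (norm_a * norm_b)


def layerwise_distance(rehearsal_grads: Dict[str, Sequence[float]],
                       kd_grads: Dict[str, Sequence[float]]) -> Dict[str, float]:
    """Cosine distance per layer between the two flattened gradient vectors."""
    return {name: cosine_distance(g, kd_grads[name])
            for name, g in rehearsal_grads.items()}


if __name__ == "__main__":
    rehearsal = {"conv1": [1.0, 0.0], "fc": [1.0, 1.0]}
    kd = {"conv1": [2.0, 0.0], "fc": [-1.0, -1.0]}
    print(layerwise_distance(rehearsal, kd))  # conv1: 0.0, fc: approximately 2.0
```

Collecting these per-layer values over training steps yields a heatmap like the one in Figure 5, with one row per layer.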