Leveraging Hierarchical Taxonomies in Prompt-based Continual Learning

Quyen Tran, Hoang Phan, Minh Le, Tuan Truong, Dinh Phung, Linh Ngo, Thien Nguyen, Nhat Ho, Trung Le

TL;DR

This work finds that applying human habits of organizing and connecting information is an efficient strategy for training deep learning models, and proposes a novel regularization loss function that encourages models to focus more on challenging knowledge areas, thereby enhancing overall performance.

Abstract

Humans perceive the world as a series of sequential events, which can be hierarchically organized with different levels of abstraction based on conceptual knowledge. Drawing inspiration from human learning behaviors, this work proposes a novel approach to mitigate catastrophic forgetting in Prompt-based Continual Learning models by exploiting the relationships between continuously emerging class data. We find that applying human habits of organizing and connecting information can serve as an efficient strategy when training deep learning models. Specifically, by building a hierarchical tree structure based on the expanding set of labels, we gain fresh insights into the data, identifying groups of similar classes that could easily cause confusion. Additionally, we delve deeper into the hidden connections between classes by exploring the original pretrained model's behavior through an optimal transport-based approach. From these insights, we propose a novel regularization loss function that encourages models to focus more on challenging knowledge areas, thereby enhancing overall performance. Experimentally, our method demonstrates significant superiority over the most robust state-of-the-art models on various benchmarks.

Paper Structure (39 sections, 9 equations, 12 figures, 5 tables)

Figures (12)

  • Figure 1: Empirical study of forgetting. We set up the experiment to eliminate factors (I) and (II), which result in feature shift after learning the sequence of tasks - Table (a). Therefore, the "within-task accuracy" on $\mathcal{D}_1$, using classification head $s_1(\bm{x})$ to classify classes within task 1 only, remains stable over time - Figure (b). However, when using head $h_\psi$ to classify all classes observed so far, the model accuracy on $\mathcal{D}_1$ decreases significantly - Figure (c), suggesting that besides feature shift, there are other factors that lead to forgetting.
  • Figure 2: Problem: When new classes arrive, the latent space of a model for all tasks so far becomes more crowded, and class representations tend to overlap, leading to performance degradation on old tasks. Our solution: We focus on separating easily confused classes whose concepts/labels lie in the same leaf group of a label-based hierarchical taxonomy, suggested by expert knowledge (i.e., ChatGPT - Figure \ref{fig:motivation_taxonomy}). [Best viewed in color mode]
  • Figure 3: The label-based hierarchical taxonomy obtained when learning Task 3 on Split-CIFAR100. The colors (i.e., blue, red, orange) of the label names represent the order in which the corresponding classes appear. Accordingly, the tree-like taxonomy is gradually developed and detailed over time. [Best viewed in color mode]
  • Figure 4: t-SNE visualizations of classes within the leaf groups Four-legged animals ($\bullet$ circular points) and Plants ($\blacktriangle$ triangular points) when learning Task 3, Split-CIFAR-100. The appearance order of the classes: Task 1 - "mouse", "porcupine", "oak tree"; Task 2 - "willow tree"; Task 3 - "otter", "hamster", "pine tree" (also refer to Figure \ref{fig:train} and \ref{fig:motivation_taxonomy}). We can see that if we train tasks independently with $\mathcal{L}_{CE}$, the classes within each leaf group, which arrive at different times, can overlap severely (Figure \ref{fig:L_CE}).
  • Figure 5: L2-Wasserstein distance between classes (Split-CIFAR-100) in the latent space of a pretrained backbone (Sup-21K). Within a leaf group (i.e., "plant" or "four-legged mammals"), there may be data classes with varying levels of correlation in the latent space.
  • ...and 7 more figures
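
To make the class-to-class distance in the Figure 5 caption concrete, the following is a minimal sketch of one common way to compute an L2-Wasserstein distance between two classes in a latent space. It assumes each class's features are approximated by a Gaussian, in which case the distance has a closed form; the text above does not state that the authors use this approximation. The function names `psd_sqrt` and `gaussian_w2` are hypothetical, and the random arrays stand in for features that would come from the pretrained backbone.

```python
import numpy as np


def psd_sqrt(mat):
    """Square root of a symmetric positive semi-definite matrix."""
    sym = (mat + mat.T) / 2.0                  # remove small asymmetries
    eigvals, eigvecs = np.linalg.eigh(sym)
    eigvals = np.clip(eigvals, 0.0, None)      # guard against tiny negatives
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def gaussian_w2(feats_a, feats_b):
    """2-Wasserstein distance between Gaussians fitted to two feature sets.

    feats_a, feats_b: arrays of shape (num_samples, feature_dim).
    Assumption: each class is summarized by its sample mean and covariance.
    """
    mean_a, mean_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)

    # Closed form: ||m_a - m_b||^2 + Tr(C_a + C_b - 2 (C_a^1/2 C_b C_a^1/2)^1/2)
    root_a = psd_sqrt(cov_a)
    cross = psd_sqrt(root_a @ cov_b @ root_a)
    squared = np.sum((mean_a - mean_b) ** 2) + np.trace(cov_a + cov_b - 2.0 * cross)
    return float(np.sqrt(max(squared, 0.0)))


# Hypothetical stand-ins for backbone features of two classes.
rng = np.random.default_rng(0)
class_a = rng.normal(size=(200, 8))
class_b = rng.normal(loc=0.5, size=(200, 8))
print(gaussian_w2(class_a, class_b))
```

Under this assumption, a small distance between two classes in the same leaf group would indicate that the pretrained backbone already places them close together, which is the kind of pair the caption describes as strongly correlated.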