Revisiting Softmax Masking: Stop Gradient for Enhancing Stability in Replay-based Continual Learning

Hoyong Kim, Minchan Kwon, Kangil Kim

TL;DR

The paper addresses catastrophic forgetting in replay-based continual learning (CL) by focusing on the pull-push dynamics induced by cross-entropy loss with softmax. It revisits softmax masking and introduces a general masked softmax, which replaces non-current-task logits with a masking value $m$ and stops the gradient flow through the masked entries, giving explicit control over gradient flow and hence stability. The authors show that negative-infinity masking ($m=-\infty$) can boost stability but may conflict with dark knowledge, and they propose a flexible masking strategy that balances stability and plasticity. They also caution that combining masked softmax with distillation can be harmful, so it must be used carefully. Across standard CL benchmarks and low-buffer scenarios, the method improves final accuracy and reduces forgetting, and the tunable masking value offers a practical handle on the stability-plasticity trade-off, including for extremely small episodic memories. Overall, the work provides a principled mechanism for controlling inter-task interference in replay-based CL and demonstrates that it improves robustness when memory is severely limited.
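A minimal sketch of the masking step described above, assuming plain cross-entropy on a single sample whose target class is unmasked. The function names `general_masked_softmax` and `masked_ce_grad` and the example mask are hypothetical, not the authors' code. Instead of relying on an autodiff framework, the sketch writes the gradient with respect to the original logits in closed form: unmasked entries get the usual softmax cross-entropy gradient $p_j - \mathbb{1}[j=y]$, while masked entries, having been replaced by the constant $m$, receive zero gradient (the stop-gradient).

```python
import math
from typing import List, Sequence


def general_masked_softmax(logits: Sequence[float], mask: Sequence[bool], m: float) -> List[float]:
    """Replace every masked logit with the constant m, then apply a numerically stable softmax."""
    z = [m if masked else o for o, masked in zip(logits, mask)]
    z_max = max(z)  # finite as long as at least one entry is unmasked
    exps = [math.exp(v - z_max) for v in z]  # math.exp(-inf) == 0.0, so m = -inf zeroes masked entries
    total = sum(exps)
    return [e / total for e in exps]


def masked_ce_grad(logits: Sequence[float], mask: Sequence[bool], m: float, target: int) -> List[float]:
    """Gradient of -log p[target] w.r.t. the ORIGINAL logits under stop-gradient masking.

    Unmasked entries receive p_j - 1[j == target]; masked entries were replaced by a
    constant, so no gradient reaches their original logits.
    """
    if mask[target]:
        raise ValueError("the target class must be unmasked")
    p = general_masked_softmax(logits, mask, m)
    return [0.0 if masked else p_j - (1.0 if j == target else 0.0)
            for j, (p_j, masked) in enumerate(zip(p, mask))]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.5, -0.5]      # classes 0-1: old task, classes 2-3: current task
    mask = [True, True, False, False]   # hypothetical mask hiding the non-current-task logits
    for m in (float("-inf"), 0.0, 2.0):
        print(m, [round(g, 4) for g in masked_ce_grad(logits, mask, m, target=2)])
```

With $m=-\infty$ the masked probabilities are exactly zero and the remaining entries match a softmax over the current-task logits alone; with a finite $m$ the masked entries keep some probability mass, so the gradients on the unmasked logits no longer sum to zero.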

Abstract

In replay-based methods for continual learning, replaying input samples stored in episodic memory has proven effective at alleviating catastrophic forgetting. However, the role of cross-entropy loss with softmax as a potential key factor in catastrophic forgetting has been underexplored. In this paper, we analyze the effect of softmax and revisit softmax masking with negative infinity to shed light on its ability to mitigate catastrophic forgetting. Our analyses show that softmax masked with negative infinity is not always compatible with dark knowledge. To improve this compatibility, we propose a general masked softmax that controls stability by adjusting the gradient scale for old and new classes. We demonstrate that applying our method to other replay-based methods yields better performance on continual learning benchmarks, primarily by enhancing model stability, even when the buffer size is set to an extremely small value.
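To make "adjusting the gradient scale" concrete, here is a short derivation in our own notation (not necessarily the paper's), assuming standard cross-entropy on a sample whose target $y$ is unmasked. Let $\mathbf{o}$ be the original logits, $\mathcal{C}$ the unmasked classes, $\mathcal{M}$ the masked classes, and $\operatorname{sg}$ the stop-gradient operator; $\mathbf{z}$ and $\mathbf{p}$ are the masked logits and confidences.

$$
z_j = \begin{cases} o_j, & j \in \mathcal{C} \\ \operatorname{sg}(m), & j \in \mathcal{M} \end{cases}
\qquad
p_j = \frac{e^{z_j}}{Z}, \qquad Z = \sum_{k \in \mathcal{C}} e^{o_k} + |\mathcal{M}|\, e^{m}
$$

$$
\mathcal{L} = -\log p_y, \qquad
\frac{\partial \mathcal{L}}{\partial o_j} =
\begin{cases} p_j - \mathbb{1}[j = y], & j \in \mathcal{C} \\ 0, & j \in \mathcal{M} \end{cases}
$$

$$
\sum_{j \in \mathcal{C}} \frac{\partial \mathcal{L}}{\partial o_j} = -\sum_{j \in \mathcal{M}} p_j = -\frac{|\mathcal{M}|\, e^{m}}{Z} \;\xrightarrow{\; m \to -\infty \;}\; 0
$$

Raising $m$ lets the constant masked entries absorb more probability mass, which shrinks every $p_j$ for $j \in \mathcal{C}$ and enlarges the pull $1 - p_y$ on the target logit, while the masked logits themselves receive no gradient for any $m$. This is one way to read the gradient-scale control described above, offered as an illustration rather than the paper's exact formulation.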

Paper Structure (30 sections, 7 equations, 4 figures, 6 tables)

This paper contains 30 sections, 7 equations, 4 figures, 6 tables.

Figures (4)

  • Figure 1: Relation between Catastrophic Forgetting and Cross-Entropy Loss with Softmax in Continual Learning.
  • Figure 2: Comparison experiments between the baseline and masked softmax in class-incremental learning. (1st row) Confidence of each class and accuracy of each task after the final task. Confidence is measured as the softmax output computed from the logits of input samples. (2nd row) Change in average accuracy for each task as the models are trained up to the $N$-th task. Task-wise average accuracy is defined as the mean of the per-class accuracies within the task. The dataset and methods used in each experiment are described in the title and caption. (Acc: task-wise average test accuracy, T$n$: $n$-th task, red: higher than baseline, blue: lower than baseline, FAA: Final Average Accuracy, Class-IL: Class-Incremental Learning, Task-IL: Task-Incremental Learning, bold: best FAA)
  • Figure 3: Overview of general masked softmax. First, masked logits $\mathbf{z}$ are formed by replacing the logits of old and new classes with a masking value $m$. Second, the softmax function is applied to the masked logits to obtain their confidence $\mathbf{p}$. Finally, the loss is backpropagated while the gradient is stopped on the replaced logits.
  • Figure 4: Change in final average accuracy with respect to the masking value. All experiments were run for 10 trials with random seeds; the mean and standard deviation of the final average accuracy are shown as a line and a band, respectively. ($\mathcal{B}$: buffer size; dotted line: final average accuracy of the baseline, DER++)
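The accuracy metrics used in these captions can be sketched as follows. Task-wise average accuracy follows the Figure 2 caption (the mean of per-class accuracies within a task); treating FAA as the unweighted mean of task-wise accuracies measured after the final task is our assumption, and all function names are hypothetical.

```python
from typing import Dict, List, Sequence


def class_accuracies(preds: Sequence[int], labels: Sequence[int]) -> Dict[int, float]:
    """Per-class accuracy: fraction of test samples of each class that are predicted correctly."""
    correct: Dict[int, int] = {}
    total: Dict[int, int] = {}
    for p, y in zip(preds, labels):
        total[y] = total.get(y, 0) + 1
        correct[y] = correct.get(y, 0) + int(p == y)
    return {c: correct[c] / total[c] for c in total}


def task_wise_accuracy(class_acc: Dict[int, float], task_classes: Sequence[int]) -> float:
    """Task-wise average accuracy: mean of the per-class accuracies of one task's classes."""
    return sum(class_acc[c] for c in task_classes) / len(task_classes)


def final_average_accuracy(preds: Sequence[int], labels: Sequence[int],
                           tasks: List[Sequence[int]]) -> float:
    """Assumed FAA: unweighted mean of task-wise accuracies, evaluated after the final task."""
    acc = class_accuracies(preds, labels)
    return sum(task_wise_accuracy(acc, t) for t in tasks) / len(tasks)


if __name__ == "__main__":
    tasks = [[0, 1], [2, 3]]              # two tasks with two classes each
    labels = [0, 0, 1, 1, 2, 2, 3, 3]
    preds = [0, 0, 1, 0, 2, 3, 1, 3]      # predictions after training on the final task
    print(final_average_accuracy(preds, labels, tasks))
```

Averaging per-class accuracies first keeps a task's score from being dominated by whichever of its classes has the most test samples.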