
Approximate Bayesian Class-Conditional Models under Continuous Representation Shift

Thomas L. Lee, Amos Storkey

TL;DR

This work tackles representation shift in online continual learning by proposing DeepCCG, an empirical Bayesian framework that places a class-conditional Gaussian classifier on top of a neural embedding. When representations shift, the classifier's posterior over class means can be updated in one step, while the embedding is trained with a log conditional marginal likelihood loss that encourages alignment with the classifier. A memory-based sample-selection scheme stores the subset of examples whose induced posterior is closest, in KL divergence, to the true posterior, limiting the information lost by keeping a fixed-size memory and enabling robust tracking of distributional shift. Experiments on CIFAR-10, CIFAR-100, and MiniImageNet show DeepCCG achieving state-of-the-art performance in both task- and class-incremental settings and demonstrating strong robustness to representation shift.

Abstract

For models consisting of a classifier in some representation space, learning online from a non-stationary data stream often necessitates changes in the representation. This raises the question of how best to adapt the classifier to shifts in representation. Current methods adapt the classifier to representation shift only slowly, introducing noise into learning because the classifier is misaligned with the representation. We propose DeepCCG, an empirical Bayesian approach to this problem. DeepCCG updates the posterior of a class-conditional Gaussian classifier so that the classifier adapts to representation shift in one step. Using a class-conditional Gaussian classifier also enables DeepCCG to update the representation with a log conditional marginal likelihood loss. To perform these updates to the classifier and representation, DeepCCG maintains a fixed number of examples in memory, so a key part of DeepCCG is selecting which examples to store: it chooses the subset that minimises the KL divergence between the true posterior and the posterior induced by the subset. We explore the behaviour of DeepCCG in online continual learning (CL), demonstrating that it performs well against a spectrum of online CL methods and that it reduces the change in performance due to representation shift.
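The memory-selection criterion can be made concrete with a minimal sketch for a single class. It rests on simplifying assumptions that are not taken from the paper: embeddings follow an isotropic Gaussian $z \sim \mathcal{N}(\mu_c, \sigma^2 I)$ with fixed $\sigma^2$, the class mean has prior $\mu_c \sim \mathcal{N}(0, \tau^2 I)$, the posterior computed from a candidate pool stands in for the "true" posterior, the divergence is taken as KL(true || subset), and the subset is found by greedy search. The helper names `posterior`, `kl_isotropic` and `greedy_select` are hypothetical.

```python
import numpy as np

# Assumed model (illustrative, not the paper's exact setup):
#   z ~ N(mu, sigma2 * I) for each embedding of one class, prior mu ~ N(0, tau2 * I).


def posterior(z, sigma2=1.0, tau2=10.0):
    """Gaussian posterior N(m, s2 * I) over one class mean, given embeddings z of shape (n, d)."""
    n = z.shape[0]
    s2 = 1.0 / (1.0 / tau2 + n / sigma2)  # posterior variance, shared by all dimensions
    m = s2 * z.sum(axis=0) / sigma2       # posterior mean
    return m, s2


def kl_isotropic(m1, s1, m2, s2):
    """KL(N(m1, s1 * I) || N(m2, s2 * I)) for d-dimensional isotropic Gaussians."""
    d = m1.shape[0]
    return 0.5 * (d * s1 / s2 + np.sum((m2 - m1) ** 2) / s2 - d + d * np.log(s2 / s1))


def greedy_select(z, k, sigma2=1.0, tau2=10.0):
    """Hypothetical greedy search: pick k rows of z whose posterior has small KL(true || subset)."""
    m_true, s_true = posterior(z, sigma2, tau2)  # reference posterior from the whole pool
    chosen, remaining = [], list(range(z.shape[0]))
    for _ in range(k):
        best_i, best_kl = None, np.inf
        for i in remaining:
            m_sub, s_sub = posterior(z[chosen + [i]], sigma2, tau2)
            kl = kl_isotropic(m_true, s_true, m_sub, s_sub)
            if kl < best_kl:
                best_i, best_kl = i, kl
        chosen.append(best_i)
        remaining.remove(best_i)
    return chosen


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    pool = rng.normal(loc=2.0, size=(50, 8))  # hypothetical embeddings of one class
    keep = greedy_select(pool, k=10)
    m_t, s_t = posterior(pool)
    m_s, s_s = posterior(pool[keep])
    print("KL(true || subset) =", kl_isotropic(m_t, s_t, m_s, s_s))
```

Under these assumptions both posterior variances depend only on the number of points, so for a fixed memory size $k$ the KL differs between candidate subsets only through $\lVert m_S - m_{\text{true}} \rVert^2 / s_S^2$; the greedy loop is then simply mean matching. A richer covariance model would make the choice of subset less trivial.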

Paper Structure

This paper contains 15 sections, 15 equations, 4 figures, 6 tables, and 1 algorithm.

Figures (4)

  • Figure 1: Diagram showing that, when updating a model on new data, current online continual learning methods adapt the classifier (i.e. the decision boundary in representation space) to representation shift only slowly. In contrast, DeepCCG quickly adapts the decision boundary, improving learning because the classifier is better matched to the current representation. In the diagram, DeepCCG adjusts the decision boundary in a single update so that the shifted representations, indicated by arrows, all remain in the correct regions, whereas for standard CL methods they do not.
  • Figure 2: Diagram of DeepCCG's training routine. At time $j$ the learner is given a sample of data $B_j$ and has a memory of stored datapoints $M_j$. The memory is randomly split into replay data $R_j$ and the rest, $M_j \setminus R_j$. Learning proceeds by taking a gradient step on the parameters $\boldsymbol{\phi}$ of the embedding function, using a log conditional marginal likelihood over $B_j$ and $R_j$; $M_j \setminus R_j$ is used to induce a posterior over the means of the per-class Gaussians and so defines the conditional marginal likelihood. Training therefore aims to move the data points into per-class clusters by drawing the embedded examples of $B_j$ and $R_j$ towards their own class means and away from the other class means, as shown by the coloured arrows in the figure.
  • Figure 3: Binned scatter plot showing, for the MiniImageNet task-incremental disjoint-tasks setting, the change in accuracy on the first task's test data against the mean change in its representation after learning on a batch. The plot shows that, for a given shift in representation, the accuracy of DeepCCG changes the least.
  • Figure 4: Binned scatter plot showing, for the MiniImageNet disjoint-tasks setting, the change in accuracy on the first task's test data against the mean change in its representation after learning on a batch.
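The training step in Figure 2 can be sketched as a loss computation. This is a minimal illustration under assumptions not stated in the source: each class has an isotropic Gaussian likelihood with fixed variance $\sigma^2$, each class mean has a $\mathcal{N}(0, \tau^2 I)$ prior, and classes have equal prior probability. Under these assumptions the posterior predictive for class $c$ is $\mathcal{N}(m_c, (\sigma^2 + s_c^2) I)$, and the negative log conditional marginal likelihood becomes a softmax cross-entropy over the log predictive densities. The function names and the linear `embed` map are hypothetical, and the gradient step on $\boldsymbol{\phi}$ is omitted; in practice it would come from an autodiff framework.

```python
import numpy as np

# Assumed model (illustrative): z ~ N(mu_c, sigma2 * I), prior mu_c ~ N(0, tau2 * I),
# equal class priors. Function names are hypothetical.


def class_posteriors(z_mem, y_mem, num_classes, sigma2=1.0, tau2=10.0):
    """Per-class posterior N(m_c, s2_c * I) over class means from memory embeddings."""
    d = z_mem.shape[1]
    means = np.zeros((num_classes, d))
    post_vars = np.zeros(num_classes)
    for c in range(num_classes):
        zc = z_mem[y_mem == c]  # a class absent from memory falls back to the prior
        s2 = 1.0 / (1.0 / tau2 + zc.shape[0] / sigma2)
        means[c] = s2 * zc.sum(axis=0) / sigma2
        post_vars[c] = s2
    return means, post_vars


def log_cond_marginal_lik(z, y, means, post_vars, sigma2=1.0):
    """Mean of log p(y | z, memory), using posterior predictives N(m_c, (sigma2 + s2_c) * I)."""
    d = z.shape[1]
    pred_var = sigma2 + post_vars                                   # (C,)
    sq = ((z[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)    # (N, C)
    log_pred = -0.5 * sq / pred_var - 0.5 * d * np.log(2 * np.pi * pred_var)
    log_norm = np.logaddexp.reduce(log_pred, axis=1)                # log sum_c p(z | c)
    return np.mean(log_pred[np.arange(len(y)), y] - log_norm)


def training_loss(embed, x_batch, y_batch, x_mem, y_mem, num_classes, n_replay, rng):
    """Negative log conditional marginal likelihood for one step at time j."""
    perm = rng.permutation(len(y_mem))
    replay, rest = perm[:n_replay], perm[n_replay:]
    # Posterior over class means induced by the non-replay memory (M_j without R_j).
    means, post_vars = class_posteriors(embed(x_mem[rest]), y_mem[rest], num_classes)
    # The loss is evaluated on the new batch B_j together with the replay data R_j.
    x = np.concatenate([x_batch, x_mem[replay]])
    y = np.concatenate([y_batch, y_mem[replay]])
    return -log_cond_marginal_lik(embed(x), y, means, post_vars)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = rng.normal(size=(5, 3))  # hypothetical linear embedding parameters (phi)

    def embed(x):
        return x @ W

    x_mem, y_mem = rng.normal(size=(40, 5)), rng.integers(0, 4, size=40)
    x_b, y_b = rng.normal(size=(8, 5)), rng.integers(0, 4, size=8)
    print(training_loss(embed, x_b, y_b, x_mem, y_mem, num_classes=4, n_replay=10, rng=rng))
```

Because the per-class posteriors are recomputed from the current embeddings of $M_j \setminus R_j$ at every step, the classifier reflects the current representation immediately rather than catching up over many updates, which is the one-step adaptation illustrated in Figure 1.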