
Thalamus: a brain-inspired algorithm for biologically-plausible continual learning and disentangled representations

Ali Hummos

TL;DR

A simple algorithm that uses optimization at inference time to dynamically generate internal representations of the current task, suggesting fundamental computations in circuits with abundant feedback control loops, such as the thalamocortical circuits in the brain.

Abstract

Animals thrive in a constantly changing environment and leverage its temporal structure to learn well-factorized causal representations. In contrast, traditional neural networks suffer from forgetting in changing environments, and many methods have been proposed to limit forgetting, each with different trade-offs. Inspired by the brain's thalamocortical circuit, we introduce a simple algorithm that uses optimization at inference time to dynamically generate internal representations of the current task. The algorithm alternates between updating the model weights and a latent task embedding, allowing the agent to parse the stream of temporal experience into discrete events and organize learning about them. On a continual learning benchmark, it achieves competitive end average accuracy by mitigating forgetting; more importantly, by requiring the model to adapt through latent updates, it organizes knowledge into flexible structures with a cognitive interface to control them. Tasks later in the sequence can be solved through knowledge transfer as they become reachable within the well-factorized latent space. The algorithm meets many of the desiderata of an ideal continually learning agent in open-ended environments, and its simplicity suggests fundamental computations in circuits with abundant feedback control loops, such as the thalamocortical circuits in the brain.
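The idea of alternating latent updates with weight updates can be illustrated with a toy NumPy sketch. Everything concrete here is an assumption rather than the paper's method: the gated linear model, the loss-threshold trigger, the step sizes, and the helper names `loss_and_grads` and `train_step` are all hypothetical. The sketch only shows the control flow: adapt the latent embedding z by gradient descent first, and fall back to a weight update only when z alone cannot solve the batch.

```python
import numpy as np

# Toy sketch (assumption): a latent task embedding z mixes K linear readouts,
#   y_hat = sum_k z[k] * (x @ W[k]).
# The loss-threshold trigger, step sizes and loop length are illustrative
# assumptions, not values from the paper.

def loss_and_grads(W, z, X, y):
    """Mean squared error and its gradients w.r.t. z (K,) and W (K, D)."""
    H = X @ W.T                               # (N, K): output of each readout
    err = H @ z - y                           # (N,): residuals
    n = len(y)
    loss = float(np.mean(err ** 2))
    grad_z = 2.0 * H.T @ err / n
    grad_W = 2.0 * np.outer(z, err @ X) / n
    return loss, grad_z, grad_W

def train_step(W, z, X, y, threshold=1e-4, latent_lr=0.2,
               weight_lr=0.05, max_latent_steps=100):
    """One batch: latent updates first; a weight update only if they fail."""
    loss, g_z, _ = loss_and_grads(W, z, X, y)
    latent_steps, weight_steps = 0, 0
    # Latent update loop: weights stay frozen, only z follows -dL/dz.
    while loss > threshold and latent_steps < max_latent_steps:
        z = z - latent_lr * g_z
        latent_steps += 1
        loss, g_z, _ = loss_and_grads(W, z, X, y)
    # Fallback: z alone could not solve the batch, so update the weights.
    if loss > threshold:
        _, _, g_W = loss_and_grads(W, z, X, y)
        W = W - weight_lr * g_W
        weight_steps = 1
        loss, _, _ = loss_and_grads(W, z, X, y)
    return W, z, loss, latent_steps, weight_steps

# Stand-in for pretraining with one-hot task IDs: readout k already solves task k.
W = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])
z = np.array([1.0, 0.0])                      # currently set to task A
X = np.random.default_rng(0).standard_normal((256, 3))

# Task B was learned before, so latent updates alone should retrieve it.
W_b, z_b, loss_b, lat_b, wt_b = train_step(W, z, X, X @ W[1])
# Task C depends on x[2], which no readout uses, so the latent loop cannot fit it.
W_c, z_c, loss_c, lat_c, wt_c = train_step(W_b, z_b, X, X[:, 2])
print(np.round(z_b, 3), lat_b, wt_b)
print(lat_c, wt_c)
```

In this toy, `z_b` should end close to `[0, 1]` with no weight update, while task C exhausts the latent loop and triggers the weight-update fallback. Unlike the model described in the Figure 1 caption, the toy never allocates a fresh embedding for a new task, so it only shows when the fallback fires, not how a new task is learned.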

Paper Structure (18 sections, 5 equations, 17 figures, 3 tables, 1 algorithm)

Figures (17)

  • Figure 1: Model schematic. Triggering latent updates and weight updates in a controlled way, and alternating between them, allows unsupervised task discovery. Latent updates retrieve previously learned tasks or choose a new embedding for new ones (PFC: prefrontal cortex; MD: mediodorsal thalamus).
  • Figure 2: (A) Schematic of the experiment: the model is pretrained with task IDs and then transitions to identifying tasks through gradients backpropagated from the loss function to update the task embedding z (latent updates). (B) Current accuracy as the latent is updated repeatedly using only one batch from the current task; we switch to the next task once the accuracy criterion is reached. (C) Current values of the latent task embedding vector z. Circular markers indicate the one-hot task ID used during pretraining. (D) Gradients of the loss with respect to z. (E) Average accuracy on all tasks when using the original task IDs from pretraining, showing no forgetting of previous tasks.
  • Figure 3: t-SNE dimensionality reduction of the latent vectors after latent updates on each task. Data points are colored by ground-truth task ID.
  • Figure 4: Thalamus performance on an unlabeled sequence of tasks. (A) Number of weight updates needed to solve each task along the training sequence for Thalamus and a baseline RNN. Thalamus's reliance on weight updates begins to vanish. (B) Latent updates needed for each task in the sequence. (C) Change in accuracy after the latent update loop. Latent updates contribute increasingly to recovering accuracy at task transitions. Results are averaged over 20 random seeds.
  • Figure 5: (A) Representational drift in the network trained on 10 tasks, measured by training a logistic regression classifier to predict task identity from the latent z embeddings of either early (first 1000 batches) or late batches and testing its accuracy on all batches. (B) Drift in the responses of a visual cortex neuron to a natural movie (MOV) or passive drifting gratings (PDG). Adapted from marks_stimulus-dependent_2021.
  • ...and 12 more figures
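Figure 5A measures drift by training a logistic regression classifier to decode task identity from latent z embeddings of one period and testing it on another. The following is a minimal sketch of that probe, assuming synthetic latents with a deliberately extreme toy drift (each task's code moves to where another task's code used to be) and a hand-written softmax regression in place of whatever solver the paper used; `fit_probe` and `probe_accuracy` are hypothetical names.

```python
import numpy as np

# Toy drift probe (assumptions: synthetic latents, hand-written softmax
# regression, hypothetical helper names; not the paper's data or solver).

def fit_probe(Z, labels, n_classes, lr=0.1, steps=500):
    """Multinomial logistic regression fitted by full-batch gradient descent."""
    n, d = Z.shape
    W = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    Y = np.eye(n_classes)[labels]                     # one-hot targets
    for _ in range(steps):
        logits = Z @ W + b
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        P = np.exp(logits)
        P /= P.sum(axis=1, keepdims=True)             # softmax probabilities
        G = (P - Y) / n                               # d(mean cross-entropy)/d(logits)
        W -= lr * Z.T @ G
        b -= lr * G.sum(axis=0)
    return W, b

def probe_accuracy(W, b, Z, labels):
    """Fraction of latents whose predicted task matches the true task."""
    return float(np.mean(np.argmax(Z @ W + b, axis=1) == labels))

rng = np.random.default_rng(1)
n_tasks, per_task = 3, 100
labels = np.repeat(np.arange(n_tasks), per_task)
early_centers = 3.0 * np.eye(n_tasks)                 # one latent cluster per task
late_centers = np.roll(early_centers, 1, axis=0)      # toy drift: task codes shift
Z_early = early_centers[labels] + 0.3 * rng.standard_normal((labels.size, n_tasks))
Z_late = late_centers[labels] + 0.3 * rng.standard_normal((labels.size, n_tasks))

W, b = fit_probe(Z_early, labels, n_tasks)            # train on "early" latents
acc_early = probe_accuracy(W, b, Z_early, labels)
acc_late = probe_accuracy(W, b, Z_late, labels)       # test on "late" latents
print(acc_early, acc_late)
```

A probe that stays accurate on late latents indicates stable task codes; accuracy that falls toward or, as in this extreme toy, below chance indicates that the codes have drifted.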