
Fast weight programming and linear transformers: from machine learning to neurobiology

Kazuki Irie, Samuel J. Gershman

TL;DR

This Primer introduces Fast Weight Programmers (FWPs), a class of recurrent networks with 2D (matrix-form) hidden states in which a slow network learns to modify the weights of a fast network. These fast weights act as short-term memory and give rise to a biologically plausible learning dynamic. The Primer establishes formal connections between FWPs and transformers, notably the equivalence between the vanilla (purely additive, unnormalized) FWP and linear attention, and surveys a spectrum of update rules (e.g., the delta rule used in DeltaNet) that shape memory behavior and expressivity. It explores local online learning, meta-learning, and in-context learning within FWPs, analyzes their expressive power relative to conventional RNNs and transformers, and discusses neurobiological interpretations and potential brain-inspired implementations. Overall, FWPs emerge as a versatile framework for sequence processing that combines parallel trainability with flexible, context-dependent weight modulation, bridging artificial sequence models and synaptic plasticity in the brain.
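The update rules mentioned above differ in how they write into the fast weight matrix. The following is a minimal sketch, not code from the Primer, assuming unit-norm keys, a fixed learning rate `beta`, and no normalization, decay, or activation functions (real models may add all of these). It contrasts the purely additive rule of vanilla FWPs/linear transformers, W <- W + v k^T, with the delta rule used in DeltaNet, W <- W + beta (v - W k) k^T. All helper names (`additive_update`, `delta_update`, etc.) are hypothetical.

```python
# Minimal sketch: two fast-weight update rules applied to a d x d matrix W.
# Pure Python lists are used so the example needs no third-party packages.

def matvec(W, x):
    """Return W @ x for a list-of-rows matrix W and a vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

def zeros(d):
    """Return a d x d matrix of zeros (an empty fast-weight memory)."""
    return [[0.0] * d for _ in range(d)]

def additive_update(W, k, v):
    """Purely additive (Hebbian-style) rule: W <- W + v k^T."""
    return [[W[i][j] + v[i] * k[j] for j in range(len(k))] for i in range(len(v))]

def delta_update(W, k, v, beta):
    """Delta rule: W <- W + beta * (v - W k) k^T (fixed beta is an assumption)."""
    v_old = matvec(W, k)  # value currently retrieved with key k
    err = [beta * (v_i - o_i) for v_i, o_i in zip(v, v_old)]
    return [[W[i][j] + err[i] * k[j] for j in range(len(k))] for i in range(len(v))]

# Write two different values under the same unit-norm key.
k = [1.0, 0.0, 0.0]
v1 = [1.0, 2.0, 3.0]
v2 = [-1.0, 0.0, 5.0]

W_add = additive_update(additive_update(zeros(3), k, v1), k, v2)
W_delta = delta_update(delta_update(zeros(3), k, v1, beta=1.0), k, v2, beta=1.0)

print(matvec(W_add, k))    # [0.0, 2.0, 8.0]: the two values are superimposed
print(matvec(W_delta, k))  # [-1.0, 0.0, 5.0]: the old value is replaced
```

With the additive rule, repeated writes to the same key interfere; the delta rule first subtracts what the memory already returns for that key, so with `beta = 1` the new value overwrites the old one.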

Abstract

Recent advances in artificial neural networks for machine learning, and language modeling in particular, have established a family of recurrent neural network (RNN) architectures that, unlike conventional RNNs with vector-form hidden states, use two-dimensional (2D) matrix-form hidden states. Such 2D-state RNNs, known as Fast Weight Programmers (FWPs), can be interpreted as neural networks whose synaptic weights (called fast weights) change dynamically over time as a function of input observations and serve as short-term memory storage; the corresponding weight modifications are controlled, or programmed, by another network (the programmer) whose parameters are trained (e.g., by gradient descent). In this Primer, we review the technical foundations of FWPs, their computational characteristics, and their connections to transformers and state space models. We also discuss connections between FWPs and models of synaptic plasticity in the brain, suggesting a convergence of natural and artificial intelligence.
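To make the 2D-state view and the connection to transformers concrete, here is a minimal sketch assuming the simplest variant: a purely additive update, identity feature maps on keys and queries, and no normalization, decay, or dynamic learning rate. The programmer is reduced to three projection matrices `Wk`, `Wv`, `Wq` (hypothetical names with toy values standing in for trained slow weights). `fwp` carries a fixed-size matrix state, while `linear_attention` keeps key and value lists that grow with the sequence; both return the same outputs because W_t q_t = sum_i v_i (k_i^T q_t).

```python
# Minimal sketch of a vanilla FWP layer (purely additive update, no
# normalization or activations) and its equivalence to unnormalized,
# causal linear attention. Pure Python, no third-party packages.

def matvec(A, x):
    """Return A @ x for a list-of-rows matrix A."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def fwp(xs, Wk, Wv, Wq):
    """Slow weights Wk, Wv, Wq program the fast weight matrix W_t."""
    d_v, d_k = len(Wv), len(Wk)
    W = [[0.0] * d_k for _ in range(d_v)]  # 2D hidden state, fixed size
    ys = []
    for x in xs:
        k, v, q = matvec(Wk, x), matvec(Wv, x), matvec(Wq, x)
        # write: W_t = W_{t-1} + v_t k_t^T
        W = [[W[i][j] + v[i] * k[j] for j in range(d_k)] for i in range(d_v)]
        ys.append(matvec(W, q))  # read: y_t = W_t q_t
    return ys

def linear_attention(xs, Wk, Wv, Wq):
    """Unnormalized causal linear attention: y_t = sum_{i<=t} v_i (k_i . q_t)."""
    K, V, ys = [], [], []  # key/value memory grows linearly with sequence length
    for x in xs:
        K.append(matvec(Wk, x))
        V.append(matvec(Wv, x))
        q = matvec(Wq, x)
        y = [0.0] * len(Wv)
        for k_i, v_i in zip(K, V):
            s = dot(k_i, q)  # unnormalized attention score
            y = [y_j + s * v_ij for y_j, v_ij in zip(y, v_i)]
        ys.append(y)
    return ys

# Toy slow weights and inputs (arbitrary illustrative values).
Wk = [[1.0, 0.0], [0.5, -1.0]]
Wv = [[2.0, 1.0], [0.0, 1.0]]
Wq = [[0.0, 1.0], [1.0, 1.0]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

out_fwp = fwp(xs, Wk, Wv, Wq)
out_att = linear_attention(xs, Wk, Wv, Wq)
assert all(abs(a - b) < 1e-9
           for ya, yb in zip(out_fwp, out_att) for a, b in zip(ya, yb))
print(out_fwp)
```

The recurrent form costs constant memory per step, which is what makes FWPs RNNs. The attention form exposes all pairwise key-query products, which can be computed in parallel over the sequence during training.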

Paper Structure

This paper contains 31 sections, 43 equations, 2 figures, 2 tables.

Figures (2)

  • Figure 1: An illustration of sequence processing in a: conventional recurrent neural networks (RNNs), b: fast weight programmers (FWPs), and c: transformers. One time step of recurrent computation is shown. In all panels, colored circles indicate time-step-specific variables that are not retained for the next time step (green), the temporally changing model state/short-term memory (yellow), and model parameters that are fixed/frozen after training (blue). In particular, the hidden state $\bm{W}_t$ of an FWP is a context-dependent, time-varying matrix, whereas the hidden state $\bm{s}_t$ of a conventional RNN is a vector, and that of a transformer consists of the key and value memory matrices $\bm{K}_t$ and $\bm{V}_t$, whose size grows linearly with the sequence length. In b, black arrows indicate computation in the fast/main net of the FWP, while the remaining gray arrows correspond to computation performed by the slow/programmer net. Activation functions and variables specific to certain FWP variants (such as a dynamic learning rate or state decay factor) are omitted for clarity; see the table of model variations for the specific update rule used in various models.
  • Figure 2: An illustration contrasting a: the conventional view of a sequence model paired with a learning algorithm, and b: a meta-learned (or in-context learning) system that embeds learning algorithms/dynamics within its sequential dynamics. In a, the sequence model only observes an input and produces an output (black), while the learning algorithm receives the expected target and the model output and adjusts the parameters of the sequence model to improve on the given task (gray). In contrast, in b, the system itself observes the input and the (delayed) expected target, and self-improvement on the task, i.e., learning, is part of its sequential dynamics.
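The idea in Figure 2b, that learning can live inside a system's sequential dynamics, can be illustrated with a delta-rule fast weight whose update is driven by a delayed target. This is a minimal sketch under strong assumptions: the programmer is reduced to a hand-set learning rate `beta` (nothing is meta-learned), keys are the raw inputs and values are the targets, and the task is a toy linear map `A` chosen for illustration; `run_in_context` is a hypothetical name. Within one sequence the fast weights perform online error-correcting (Widrow-Hoff/LMS-style) updates, so the prediction error falls while the slow parameters stay fixed.

```python
# Minimal sketch: a delta-rule fast-weight system that learns a linear map
# within its own sequential dynamics. At each step it receives an input x_t,
# emits a prediction W x_t, and then receives the (delayed) target y_t, which
# drives the fast-weight update. beta is hand-set here, not meta-learned.
import random

def matvec(W, x):
    """Return W @ x for a list-of-rows matrix W."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def run_in_context(stream, d_in, d_out, beta=0.5):
    W = [[0.0] * d_in for _ in range(d_out)]  # fast weights start empty for each sequence
    errors = []
    for x, y in stream:
        pred = matvec(W, x)                          # output before the target is revealed
        err = [yi - pi for yi, pi in zip(y, pred)]
        errors.append(sum(e * e for e in err))       # squared prediction error at this step
        # delta-rule update driven by the delayed target: W <- W + beta (y - W x) x^T
        W = [[W[i][j] + beta * err[i] * x[j] for j in range(d_in)] for i in range(d_out)]
    return errors

random.seed(0)
A = [[1.0, -2.0], [0.5, 3.0]]  # hypothetical task: a linear map unknown to the system
stream = []
for _ in range(50):
    x = [random.uniform(-1, 1), random.uniform(-1, 1)]
    stream.append((x, matvec(A, x)))

errors = run_in_context(stream, d_in=2, d_out=2)
print(errors[0], errors[-1])  # the final error should be far below the first
```

With inputs in [-1, 1]^2 and `beta = 0.5`, each update can only shrink the residual map (A minus W) along the current input direction, so the error decays over the sequence. In an actual metalearned FWP, the slow net would be trained so that such in-sequence learning emerges for a whole family of tasks.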