Weight-Space Linear Recurrent Neural Networks

Roussel Desmond Nzoyem, Nawid Keshtmand, Enrique Crespo Fernandez, Idriss Tsayem, Raul Santos-Rodriguez, David A. W. Barton, Tom Deakin

TL;DR

This work introduces WARP (Weight-space Adaptive Recurrent Prediction), a simple yet powerful model that unifies weight-space learning with linear recurrence to redefine sequence modeling and solidify weight-space linear RNNs as a transformative paradigm for adaptive machine intelligence.

Abstract

We introduce WARP (Weight-space Adaptive Recurrent Prediction), a simple yet powerful model that unifies weight-space learning with linear recurrence to redefine sequence modeling. Unlike conventional recurrent neural networks (RNNs), which collapse temporal dynamics into fixed-dimensional hidden states, WARP explicitly parametrizes its hidden state as the weights and biases of a distinct auxiliary neural network, and uses input differences to drive its recurrence. This brain-inspired formulation enables efficient gradient-free adaptation of the auxiliary network at test time, in-context learning abilities, and seamless integration of domain-specific physical priors. Empirical validation shows that WARP matches or surpasses state-of-the-art baselines on diverse classification tasks, ranking in the top three on 4 of 6 challenging real-world datasets. Furthermore, extensive experiments across sequential image completion, multivariate time series forecasting, and dynamical system reconstruction demonstrate its expressiveness and generalisation capabilities. Remarkably, a physics-informed variant of our model outperforms the next best model by more than 10x. Ablation studies confirm the architectural necessity of key components, solidifying weight-space linear RNNs as a transformative paradigm for adaptive machine intelligence.
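The mechanism described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes the recurrence takes the form $\theta_t = A\theta_{t-1} + B\,\Delta\mathbf{x}_t$ with $\Delta\mathbf{x}_t = \mathbf{x}_t - \mathbf{x}_{t-1}$, which the transition matrices $(A, B)$ in the Figure 2 caption suggest but this page does not state. The layer sizes, the `tanh` activation, the random stand-ins for the learned $A$, $B$ and initial state, and the helper names `unflatten` and `aux_mlp` are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration.
D_x, H, T = 2, 4, 10  # input dimension, auxiliary MLP width, sequence length

# Auxiliary MLP: scalar coordinate tau -> H hidden units -> D_x outputs.
shapes = [(H, 1), (H,), (D_x, H), (D_x,)]
D_theta = sum(int(np.prod(s)) for s in shapes)  # size of the flat hidden state


def unflatten(theta):
    """Split the flat hidden state into the MLP's weights and biases."""
    params, i = [], 0
    for s in shapes:
        n = int(np.prod(s))
        params.append(theta[i:i + n].reshape(s))
        i += n
    return params


def aux_mlp(theta, tau):
    """Evaluate the auxiliary network whose parameters are the hidden state."""
    W1, b1, W2, b2 = unflatten(theta)
    h = np.tanh(W1 @ np.array([tau]) + b1)
    return W2 @ h + b2


# Random stand-ins for quantities that would be learned (assumption).
A = np.eye(D_theta)                             # state transition matrix
B = 0.01 * rng.standard_normal((D_theta, D_x))  # input-difference matrix
theta = 0.1 * rng.standard_normal(D_theta)      # stand-in for the initial state

x = rng.standard_normal((T, D_x))  # toy input sequence
outputs = []
for t in range(1, T):
    delta_x = x[t] - x[t - 1]        # the input difference drives the recurrence
    theta = A @ theta + B @ delta_x  # linear update in weight space, no gradients
    outputs.append(aux_mlp(theta, t / (T - 1)))  # decode at normalised time tau
outputs = np.stack(outputs)
```

In this sketch the hidden state never leaves weight space: each step is a single linear map, and the only non-linearity sits inside the auxiliary network that decodes $\theta_t$.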

Paper Structure

This paper contains 85 sections, 2 theorems, 16 equations, 23 figures, 16 tables, 2 algorithms.

Key Result

Theorem 1

Assume $B \in \mathbb{R}^{D_{\theta} \times D_x}$ is a full row-rank matrix. There exist $\Delta \mathbf{x}_{0} \in \mathbb{R}^{D_x}$ and a length-$T$ kernel $K$ such that $\theta_{0:T} = K \star \Delta \mathbf{x}_{0:T}$.
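A small numerical check may make the statement concrete. It is a sketch under assumptions not stated on this page: that the recurrence is $\theta_t = A\theta_{t-1} + B\,\Delta\mathbf{x}_t$, that the kernel is $K_k = A^k B$, and that $\Delta\mathbf{x}_0$ is chosen so that $B\,\Delta\mathbf{x}_0 = \theta_0$ (possible because a full row-rank $B$ is surjective). The dimensions and random matrices are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
D_theta, D_x, T = 3, 5, 8  # D_x >= D_theta so that B can have full row rank

A = 0.9 * np.eye(D_theta) + 0.05 * rng.standard_normal((D_theta, D_theta))
B = rng.standard_normal((D_theta, D_x))
assert np.linalg.matrix_rank(B) == D_theta  # full row rank

theta0 = rng.standard_normal(D_theta)  # arbitrary initial hidden state
dx = rng.standard_normal((T, D_x))     # input differences for steps 0..T-1

# Full row rank gives B @ pinv(B) = I, so this dx[0] satisfies B @ dx[0] = theta0.
dx[0] = np.linalg.pinv(B) @ theta0

# Recurrent mode: theta_t = A theta_{t-1} + B dx_t, starting from theta0.
thetas_rec = [theta0]
for t in range(1, T):
    thetas_rec.append(A @ thetas_rec[-1] + B @ dx[t])
thetas_rec = np.stack(thetas_rec)

# Convolution mode: theta_t = sum_{s=0}^{t} K_{t-s} dx_s with K_k = A^k B.
K = np.stack([np.linalg.matrix_power(A, k) @ B for k in range(T)])
thetas_conv = np.stack([
    sum(K[t - s] @ dx[s] for s in range(t + 1)) for t in range(T)
])

assert np.allclose(thetas_rec, thetas_conv)
```

Absorbing the initial state into a fictitious first input difference removes the separate $A^t\theta_0$ term, so the whole trajectory becomes a single causal convolution.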

Figures (23)

  • Figure 1: Background and conceptual comparison between RNN architectures. Standard RNNs (e.g. hochreiter1997long; cho2014learning) feature a non-linear transition function $f_{\Phi}$ unlike their linear counterparts (e.g. gu2021efficiently; orvieto2023resurrecting). Our proposed weight-space linear RNNs view their hidden state, denoted as $\theta_t$, as the parameters of a family of functions. As observed in the bottom-right corner, $\theta_t$ represents, in the general case, the flattened weights of an MLP at time step $t$. Its input $\tau$ is a (concatenation of) coordinate system(s) to maximally make use of the canonical ordering of the sequence.
  • Figure 2: (Left) General sequence modelling setting. In the forecasting scenario, for instance, a context of length $L$ informs the prediction of future states. (Right) WARP's unfolded recurrence. The initial hypernetwork $\phi$ and transition matrices $(A,B)$, highlighted in orange, are learnable parameters, fitted via conventional gradient descent.
  • Figure 3: (a) Comparison of a GRU (cho2014learning), LSTM (hochreiter1997long), S4 (alexander2022theannotateds4), and WARP on the MNIST image completion task with $L=300$ initial pixels. All models have roughly 1.7M parameters, with architectures described in \ref{subsec:baselines}. The leftmost column represents target images with context (in white) and ground truths (in green). Predicted forecasts are drawn in red. (b) Heatmap of test MSEs ($\downarrow$) on the ETT task, with best results enclosed and second-best underlined.
  • Figure 4: Sample LV input/output.
  • Figure 5: Pipeline and results for in-context learning. (a) Cumulative sum transformation and subsequent processing of the input matrix. (b) Linear mappings learned between scalar keys and values of the same sequences ($D_x=2$). (c) Ground truth vs. query point predictions ($D_x=8$).
  • ...and 18 more figures

Theorems & Definitions (4)

  • Theorem 1: Convolution Mode
  • proof
  • Theorem 2: Existence of an Initial Input Difference
  • proof