
Hidden Traveling Waves bind Working Memory Variables in Recurrent Neural Networks

Arjun Karuvally, Terrence J. Sejnowski, Hava T. Siegelmann

TL;DR

The paper investigates how traveling wave dynamics can bind working memory variables in recurrent networks to store history-dependent information. It introduces Traveling Wave Memory (TWM) with two boundary conditions: Linear Boundary Condition (LBC) linking to RNNs, and Self-attention Boundary Condition (SBC) linking to transformers. Theoretical results show a linear operator $\Phi$ governing wave propagation under LBC and a non-linear self-attention boundary under SBC, with empirical evidence that trained RNNs converge to these dynamics and encode recent history as traveling waves. These findings suggest traveling waves as a unifying memory substrate that can improve gradient propagation and inform future neural architectures.
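
Here is a minimal sketch of the linear operator $\Phi$ under LBC. It assumes the update rule suggested by the Figure 2 and Figure 4 captions: for $i<s$, $h_i(t+1)=h_{i+1}(t)$, and at the start boundary $h_s(t+1)=\sum_{k=1}^{s} A_k h_k(t)$. The $d\times d$ matrices $A_k$ are hypothetical boundary weights. If the lattice is flattened into one vector of length $sd$, the update becomes $\Phi = S + L$. Here $S$ is a block shift matrix, and $L$ is nonzero only in its last $d$ rows, so $\operatorname{rank} L \le d$. The helpers `wave_step`, `build_phi`, and `matvec` are illustrative names, not code from the paper, and external inputs are omitted.

```python
from typing import List

Matrix = List[List[float]]


def wave_step(h: Matrix, A: List[Matrix]) -> Matrix:
    """One assumed LBC step on an s x d lattice (h[0] is index 1, h[-1] is index s)."""
    s, d = len(h), len(h[0])
    # Interior: h_i(t+1) = h_{i+1}(t); the old h_1 leaves through the open end.
    new = [row[:] for row in h[1:]]
    # Start boundary: h_s(t+1) = sum_k A_k h_k(t), a linear function of the whole lattice.
    boundary = [
        sum(A[k][r][c] * h[k][c] for k in range(s) for c in range(d)) for r in range(d)
    ]
    return new + [boundary]


def build_phi(s: int, d: int, A: List[Matrix]) -> Matrix:
    """Flattened operator Phi = S + L acting on vec(h), a vector of length s*d."""
    n = s * d
    phi = [[0.0] * n for _ in range(n)]
    # S: identity blocks on the block superdiagonal (variable i reads variable i+1).
    for i in range(s - 1):
        for r in range(d):
            phi[i * d + r][(i + 1) * d + r] = 1.0
    # L: only the last d rows are nonzero, so rank(L) <= d.
    for k in range(s):
        for r in range(d):
            for c in range(d):
                phi[(s - 1) * d + r][k * d + c] = A[k][r][c]
    return phi


def matvec(M: Matrix, v: List[float]) -> List[float]:
    """Plain matrix-vector product."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


# Usage: the flattened operator reproduces the lattice update exactly.
s, d = 3, 2
A = [
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.5], [0.0, 0.0]],
    [[0.0, 0.0], [2.0, 1.0]],
]
h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(matvec(build_phi(s, d, A), [x for row in h for x in row]))  # → [3.0, 4.0, 5.0, 6.0, 3.0, 16.0]
```

The shift part depends only on $s$ and $d$. All learned structure lives in the $d$ boundary rows, which is why the operator splits into a shift matrix plus a low-rank term.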

Abstract

Traveling waves are a fundamental phenomenon in the brain, playing a crucial role in short-term information storage. In this study, we leverage the concept of traveling wave dynamics within a neural lattice to formulate a theoretical model of neural working memory, study its properties, and examine its real-world implications in AI. The proposed model diverges from traditional approaches, which assume that information is stored in static, register-like locations updated by interference. Instead, the model stores data as waves that are updated by the wave's boundary conditions. We rigorously examine the model's capabilities in representing and learning state histories, which are vital for learning history-dependent dynamical systems. The findings reveal that the model reliably stores external information and enhances the learning process by addressing the diminishing gradient problem. To understand the model's real-world applicability, we explore two cases: a linear boundary condition (LBC) and a non-linear, self-attention-driven boundary condition (SBC). The model with the linear boundary condition yields a shift matrix plus a low-rank matrix, the structure currently used in the H3 state space RNN. Further, our experiments with LBC reveal that this matrix is effectively learned by Recurrent Neural Networks (RNNs) through backpropagation when modeling history-dependent dynamical systems. Conversely, the SBC parallels the autoregressive loop of an attention-only transformer, with the context vector representing the wave substrate. Collectively, our findings suggest the broader relevance of traveling waves in AI and their potential for advancing neural network architectures.
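
The SBC loop can be sketched in the same lattice terms. This sketch assumes the substrate holds the last $s$ token vectors, with index $s$ the newest. It also assumes the boundary is a single attention head with identity query, key, and value projections, whose query is the newest variable. Both simplifications are hypothetical and stand in for the learned projections of an attention-only transformer. `attention_boundary` and `sbc_step` are illustrative names. Each step shifts the wave toward index 1 and writes the attention output at index $s$, which matches how a fixed-length context window slides in an autoregressive loop.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_boundary(h: List[Vector]) -> Vector:
    """Hypothetical single-head attention with identity projections.

    The query is the newest variable h[-1]; keys and values are all s variables.
    """
    q = h[-1]
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in h]
    weights = softmax(scores)
    # Output is a convex combination of the stored variables.
    return [sum(w * v[j] for w, v in zip(weights, h)) for j in range(d)]


def sbc_step(h: List[Vector]) -> List[Vector]:
    """Shift the wave toward index 1, then write the attention output at index s."""
    return [row[:] for row in h[1:]] + [attention_boundary(h)]


# Usage: iterate the loop; the substrate acts as a sliding context window of length s.
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for _ in range(3):
    h = sbc_step(h)
    print([round(x, 3) for x in h[-1]])
```

Unlike the LBC sketch, the boundary here is non-linear in the substrate because the attention weights depend on the stored variables themselves.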

Paper Structure

This paper contains 23 sections, 1 theorem, 44 equations, 12 figures, 2 tables, and 1 algorithm.

Key Result

Theorem 3.1

Any history-dependent dynamical system with a state dimension of $d$, a history of $s$ states, and an evolution function $f$ can be represented in the traveling wave model.
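
The representation in Theorem 3.1 can be illustrated with a small simulation. The sketch assumes the lattice layout of Figure 2: the $s$ most recent states occupy indices $1$ (oldest) to $s$ (newest), each variable copies from its neighbor at the next index, and the start boundary evaluates $f$ on the whole substrate. The functions `run_direct` and `run_twm` are hypothetical names. The example system, a Fibonacci recurrence with $d=1$ and $s=2$, was chosen only because its trajectory is easy to check. This sketch shows the representation, not the paper's proof.

```python
from typing import Callable, List

State = List[float]
Evolution = Callable[[List[State]], State]


def run_direct(f: Evolution, history: List[State], steps: int) -> List[State]:
    """Evolve x_{t+1} = f(x_{t-s+1}, ..., x_t) with an explicit history buffer."""
    s = len(history)
    traj = [x[:] for x in history]
    for _ in range(steps):
        traj.append(f(traj[-s:]))
    return traj[s:]


def run_twm(f: Evolution, history: List[State], steps: int) -> List[State]:
    """Same system as a traveling wave: h[0] is index 1 (oldest), h[-1] is index s (newest)."""
    h = [x[:] for x in history]
    out = []
    for _ in range(steps):
        boundary = f(h)         # the start boundary computes f of the whole substrate
        h = h[1:] + [boundary]  # the wave moves toward index 1; the oldest state exits the open end
        out.append(boundary)
    return out


def fib(h: List[State]) -> State:
    """Example history-dependent system with d = 1, s = 2: x_{t+1} = x_{t-1} + x_t."""
    return [h[0][0] + h[1][0]]


print(run_twm(fib, [[0.0], [1.0]], 6))  # → [[1.0], [2.0], [3.0], [5.0], [8.0], [13.0]]
```

The substrate never stores more than $s$ states. The wave supplies exactly the history window that $f$ needs, and the open end discards anything older.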

Figures (12)

  • Figure 1: Illustration of information storage in traveling waves - A prominent hypothesis on the computational utility of traveling waves says that information is stored as ripple-like waves that propagate outwards with time. A snapshot of the resulting wavefield provides all the information necessary to reconstruct the recent history by encoding both when and where (in the dimensions of the stimulus) a stimulus occurred.
  • Figure 2: Traveling Wave Memory Architecture - The traveling wave based working memory architecture is composed of a neural substrate $h_{i j}$ with neurons arranged in a rectangular lattice. $d$ independent waves travel from the column with index $i=s$ down to $i=1$. These waves do not interact with each other as they travel through the substrate. The end boundary of the substrate is left open so there is no interference from reflected waves, and the start boundary condition computes a function $f$ of the entire neural substrate at the previous time step. This simple model is found to underlie working memory storage in RNNs.
  • Figure 3: Basis transformation reveals traveling waves encoding the recent past in the repeat copy task (with $s=d=8$): A. In the repeat copy task ($\mathcal{T}_1$), the RNN must repeatedly reproduce an input sequence that is presented to it. A typical hidden state of the trained RNN after the input is provided shows no meaningful pattern connected to the input. B. When the basis of the same hidden states is transformed, the input information is revealed to be stored as waves of activity traveling from the variable with index $8$ down to the variable with index $1$, repeatedly updated by the boundary condition.
  • Figure 4: TWM improves human interpretation of the RNN parameters (with $s=d=8$): When visualized in the LBC basis, the learned weights take a human-interpretable form. For RNNs trained on two sample tasks, $\mathcal{T}_1$ (A, left) and $\mathcal{T}_2$ (B, right), the weight matrix $W_{hh}$ converts into a form that reveals the internal mechanisms the RNNs use to solve the two tasks. For both tasks, each variable with index $<8$ copies its contents to the preceding variable, resulting in a wave of activity. Variable $8$ actively computes the function $f$ of all the variables stored in the hidden state as the boundary condition. For $\mathcal{T}_1$, the boundary condition is a simple copy of the $1^{\text{st}}$ variable; for $\mathcal{T}_2$, it is a linear combination of all the variables. Notably, the circuit for $\mathcal{T}_2$ shows an optimized basis in which the wave for each dimension travels only as far as needed to store what the computation requires.
  • Figure 5: Analysis of RNN gradient propagation reveals that the diminishing gradient problem is alleviated during training, as predicted by the TWM: Tracking the gradient norm with respect to the RNN inputs shows that the diminishing gradient issue lessens during training. A. During early training iterations, the gradient norm decays exponentially the farther it is propagated (denoted by lower input indices). This diminishing behavior is gradually alleviated as training progresses, with the gradient norm in later iterations approaching $1$, the ideal gradient norm for preventing the diminishing/exploding gradient issue. B. The decay rate (red line) decreases sharply during training, after which it remains close to 0. The transition to a decay rate of 0 happens when the absolute value of the maximum eigenvalue (blue line) crosses 1. At this point, the eigenvalue crosses the unit circle in the complex plane and traveling waves are set up. Taken together, the two plots verify the predictions of the TWM regarding the gradient propagation behavior of RNNs.
  • ...and 7 more figures
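
Figures 4 and 5 link a copy-style boundary, eigenvalues on the unit circle, and a non-decaying gradient norm. This link can be checked on an idealized circuit. The sketch assumes the repeat-copy weights of Figure 4A: $h_i(t+1)=h_{i+1}(t)$ for $i<s$ and $h_s(t+1)=\gamma\,h_1(t)$. The scale $\gamma$ is a hypothetical parameter added for contrast. Under this rule, a signal passes through the boundary once every $s$ steps, so $\Phi^s=\gamma I$ and every eigenvalue satisfies $|\lambda|=\gamma^{1/s}$. For a linear recurrence, backpropagation multiplies the gradient by $\Phi^\top$ at each step. With $\gamma=1$ (a pure copy, where $\Phi$ is a permutation), the gradient norm therefore stays exactly 1. With $\gamma<1$, it shrinks by a factor of $\gamma$ every $s$ steps. `copy_task_phi` and `backprop_norms` are illustrative names, and the nonlinearities and input/output weights of a real RNN are ignored.

```python
import math
from typing import List

Matrix = List[List[float]]


def copy_task_phi(s: int, d: int, gamma: float) -> Matrix:
    """Idealized repeat-copy circuit: h_i <- h_{i+1} for i < s, and h_s <- gamma * h_1."""
    n = s * d
    phi = [[0.0] * n for _ in range(n)]
    for i in range(s - 1):  # shift part: the wave moves toward index 1
        for r in range(d):
            phi[i * d + r][(i + 1) * d + r] = 1.0
    for r in range(d):  # boundary: index s reads index 1, scaled by gamma
        phi[(s - 1) * d + r][r] = gamma
    return phi


def backprop_norms(phi: Matrix, grad: List[float], steps: int) -> List[float]:
    """Norms of (Phi^T)^k grad for k = 0..steps, as in a linear RNN's backward pass."""
    n = len(phi)
    norms = [math.sqrt(sum(g * g for g in grad))]
    for _ in range(steps):
        grad = [sum(phi[r][c] * grad[r] for r in range(n)) for c in range(n)]  # grad <- Phi^T grad
        norms.append(math.sqrt(sum(g * g for g in grad)))
    return norms


unit = [1.0] + [0.0] * 15  # s = 8, d = 2, so the flattened state has 16 entries
copy_norms = backprop_norms(copy_task_phi(8, 2, 1.0), unit, 16)
leaky_norms = backprop_norms(copy_task_phi(8, 2, 0.5), unit, 16)
print(copy_norms[8], leaky_norms[8], leaky_norms[16])  # → 1.0 0.5 0.25
```

The $\gamma=1$ case corresponds to the eigenvalue magnitude sitting at 1 in Figure 5B. Any contraction at the boundary reintroduces a geometric decay of the gradient with distance.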

Theorems & Definitions (1)

  • Theorem 3.1