Revisiting Bi-Linear State Transitions in Recurrent Neural Networks

M. Reza Ebrahimi, Roland Memisevic

TL;DR

This work rethinks state tracking by treating hidden units as active processors through bilinear, input-dependent state transitions, formalized as $h^t_i = \sum_{j,k} \mathcal{W}_{ijk} x^t_k h^{t-1}_j$. It shows that such bilinear dynamics yield an input-driven state-transition operator $\mathcal{A}_{x}$ and can simulate arbitrary finite-state machines; through factorized and block-diagonal variants, it achieves a scalable hierarchy of expressiveness. Empirically, bilinear models outperform non-bilinear baselines on modular addition, random FSMs, and modular arithmetic, with rotation-based $\mathcal{R}_2$ blocks capturing commutative structure and diagonal variants revealing parity capabilities even with frozen recurrent weights. The results imply that multiplicative interactions are a powerful mechanism for state tracking, offer insights into when additive terms help or hurt, and raise questions about parameter efficiency and applicability to large-scale tasks.
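The update rule quoted above can be read as "the input selects a matrix, and that matrix is applied to the previous hidden state". The following is a minimal pure-Python sketch of that reading, not the authors' implementation. The function names and the toy tensor `W` are illustrative assumptions; the only thing taken from the text is the formula $h^t_i = \sum_{j,k} \mathcal{W}_{ijk} x^t_k h^{t-1}_j$.

```python
def bilinear_step(W, x, h_prev):
    """One bilinear update: h[i] = sum over j,k of W[i][j][k] * x[k] * h_prev[j]."""
    return [
        sum(W[i][j][k] * x[k] * h_prev[j]
            for j in range(len(h_prev))
            for k in range(len(x)))
        for i in range(len(W))
    ]


def transition_matrix(W, x):
    """Input-dependent operator A_x, with A_x[i][j] = sum over k of W[i][j][k] * x[k]."""
    return [
        [sum(W[i][j][k] * x[k] for k in range(len(x))) for j in range(len(W[i]))]
        for i in range(len(W))
    ]


def matvec(A, h):
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, h)) for row in A]


# Toy tensor (an assumption for illustration): input [1, 0] selects the
# identity matrix, input [0, 1] selects the swap matrix.
W = [
    [[1, 0], [0, 1]],
    [[0, 1], [1, 0]],
]

h0 = [1, 0]
h1 = bilinear_step(W, [0, 1], h0)                 # swap the two hidden units
h1_via_A = matvec(transition_matrix(W, [0, 1]), h0)
```

Both routes give the same state, which is the sense in which the bilinear form defines an input-driven transition operator $\mathcal{A}_x$ acting linearly on $h^{t-1}$.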

Abstract

The role of hidden units in recurrent neural networks is typically seen as modeling memory, with research focusing on enhancing information retention through gating mechanisms. A less explored perspective views hidden units as active participants in the computation performed by the network, rather than passive memory stores. In this work, we revisit bilinear operations, which involve multiplicative interactions between hidden units and input embeddings. We demonstrate theoretically and empirically that they constitute a natural inductive bias for representing the evolution of hidden states in state tracking tasks. These are the simplest type of tasks that require hidden units to actively contribute to the behavior of the network. We also show that bilinear state updates form a natural hierarchy corresponding to state tracking tasks of increasing complexity, with popular linear recurrent networks such as Mamba residing at the lowest-complexity center of that hierarchy.

Paper Structure

This paper contains 27 sections, 6 theorems, 19 equations, 4 figures, 5 tables.

Key Result

Proposition 1

The bilinear state transition model $h^t_i = \sum_{j,k} \mathcal{W}_{ijk} x^t_k h^{t-1}_j$ is capable of simulating any finite state machine $\mathcal{S} = (Q, \Sigma, \delta, q_0)$.
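One way to see why a statement like this can hold is the standard one-hot construction: encode states and input symbols as one-hot vectors and set the tensor entry to 1 exactly where $\delta$ maps state $j$ under symbol $k$ to state $i$. The sketch below assumes this construction; the paper's own proof may proceed differently, and the example machine, function names, and constants are hypothetical.

```python
def one_hot(index, size):
    return [1 if i == index else 0 for i in range(size)]


def build_transition_tensor(states, alphabet, delta):
    """W[i][j][k] = 1 iff delta(states[j], alphabet[k]) == states[i]."""
    n, m = len(states), len(alphabet)
    W = [[[0] * m for _ in range(n)] for _ in range(n)]
    for j, q in enumerate(states):
        for k, s in enumerate(alphabet):
            i = states.index(delta[(q, s)])
            W[i][j][k] = 1
    return W


def run_bilinear(W, states, alphabet, q0, word):
    """Track the FSM state using only bilinear updates on one-hot vectors."""
    h = one_hot(states.index(q0), len(states))
    for symbol in word:
        x = one_hot(alphabet.index(symbol), len(alphabet))
        h = [
            sum(W[i][j][k] * x[k] * h[j]
                for j in range(len(h))
                for k in range(len(x)))
            for i in range(len(W))
        ]
    return states[h.index(1)]


def run_fsm(delta, q0, word):
    """Reference simulation that applies delta directly."""
    q = q0
    for symbol in word:
        q = delta[(q, symbol)]
    return q


# Hypothetical 3-state machine: 'a' advances the state mod 3, 'b' resets to 0.
STATES = [0, 1, 2]
ALPHABET = ["a", "b"]
DELTA = {(q, "a"): (q + 1) % 3 for q in STATES}
DELTA.update({(q, "b"): 0 for q in STATES})
W = build_transition_tensor(STATES, ALPHABET, DELTA)
```

Because each input symbol selects a 0/1 matrix with exactly one 1 per column, the hidden state stays one-hot at every step, so the bilinear run and the direct run agree on every input word.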

Figures (4)

  • Figure 1: Taxonomy of bilinear RNNs studied in this paper, along with example regular language tasks they can learn (in blue).
  • Figure 2: Data efficiency comparison between bilinear models and LSTM/RNN on state tracking tasks. All models were trained on sequences of length 10 and evaluated on length 500, with varying training set sizes. Despite their large parameter counts, bilinear models, including the full variant, exhibit better data efficiency than LSTM.
  • Figure 3: (Left) Effect of (input-dependent) additive terms in the hidden state update rule on the OOD accuracy of the modular addition task. (Right) Length generalization performance on parity with a random multiplicative RNN, with and without additive terms.
  • Figure 4: Visualization of the rotation angles learned by the $\mathcal{R}_2$ block-diagonal model for each input integer in the $m=10$ modular addition task. Each subplot corresponds to a distinct 2-dimensional hidden state subspace. These subspaces are ordered based on the magnitude of the classifier weights. The "harmonic" is $\frac{\theta_1}{2\pi/m}$, where $\theta_1$ is the learned rotation angle for the integer 1.
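To make the rotation picture in Figure 4 concrete, the sketch below tracks a running sum modulo $m$ in a single 2-dimensional subspace by rotating the state by $2\pi a/m$ for each input integer $a$. This is an idealized illustration with hand-set angles at the first harmonic, not the learned angles from the figure, and the function names are assumptions.

```python
import math


def rotation(theta):
    """2x2 rotation matrix for angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def rotate(R, h):
    return [R[0][0] * h[0] + R[0][1] * h[1],
            R[1][0] * h[0] + R[1][1] * h[1]]


def modular_sum_by_rotation(inputs, m):
    """Compute sum(inputs) mod m using only input-dependent rotations of a 2-d state."""
    h = [1.0, 0.0]                      # start at angle 0
    step = 2 * math.pi / m              # angle assigned to the integer 1
    for a in inputs:
        h = rotate(rotation(a * step), h)
    angle = math.atan2(h[1], h[0])      # accumulated angle in (-pi, pi]
    return round(angle / step) % m      # decode the nearest multiple of step
```

Rotations in the plane commute, so the decoded result does not depend on the order of the inputs, which matches the commutative structure of modular addition mentioned for the $\mathcal{R}_2$ blocks.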

Theorems & Definitions (9)

  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Proposition 3
  • proof
  • Proposition 3
  • proof
  • Proposition 3
  • proof