M$^2$RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling

Mayank Mishra, Shawn Tan, Ion Stoica, Joseph Gonzalez, Tri Dao

Abstract

Transformers are highly parallel but are limited to computations in the TC$^0$ complexity class, excluding tasks such as entity tracking and code execution that provably require greater expressive power. Motivated by this limitation, we revisit non-linear Recurrent Neural Networks (RNNs) for language modeling and introduce Matrix-to-Matrix RNN (M$^2$RNN): an architecture with matrix-valued hidden states and expressive non-linear state transitions. We demonstrate that the language modeling performance of non-linear RNNs is limited by their state size. We also demonstrate how the state size expansion mechanism enables efficient use of tensor cores. Empirically, M$^2$RNN achieves perfect state tracking generalization at sequence lengths not seen during training. These benefits also translate to large-scale language modeling. In hybrid settings that interleave recurrent layers with attention, Hybrid M$^2$RNN outperforms equivalent Gated DeltaNet hybrids by $0.4$-$0.5$ perplexity points on a 7B MoE model, while using $3\times$ smaller state sizes for the recurrent layers. Notably, replacing even a single recurrent layer with M$^2$RNN in an existing hybrid architecture yields accuracy gains comparable to Hybrid M$^2$RNN with minimal impact on training throughput. Further, the Hybrid Gated DeltaNet models with a single M$^2$RNN layer also achieve superior long-context generalization, outperforming state-of-the-art hybrid linear attention architectures by up to $8$ points on LongBench. Together, these results establish non-linear RNN layers as a compelling building block for efficient and scalable language models.
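The abstract does not spell out the update rule, so the NumPy sketch below only illustrates the general idea of a non-linear RNN with a matrix-valued state. It assumes a hypothetical single-head update $H_t = \tanh(f_t\, H_{t-1} W + k_t v_t^\top)$ with readout $o_t = q_t^\top H_t$. It also assumes a Mamba-2-style forget gate $f_t = \exp(-e^{\alpha}\,\mathrm{softplus}(x_t + \beta))$ as a stand-in for the $\alpha_n$, $\beta_n$ gate plotted in Figure 1. The function names and both formulas are illustrative assumptions, not the paper's equations.

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)

def forget_gate(x, alpha, beta):
    # ASSUMED Mamba-2-style gate: f = exp(-exp(alpha) * softplus(x + beta)), in (0, 1).
    return np.exp(-np.exp(alpha) * softplus(x + beta))

def m2rnn_head(q, k, v, gate_in, W, alpha, beta):
    """Hypothetical single-head matrix-state RNN (NOT the paper's exact equations).

    q, k: (T, d_k); v: (T, d_v); gate_in: (T,); W: (d_v, d_v).
    The hidden state H is a (d_k, d_v) matrix rather than a vector.
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    H = np.zeros((d_k, d_v))
    outputs = np.empty((T, d_v))
    for t in range(T):  # tanh makes the recurrence non-linear, so it is run step by step
        f = forget_gate(gate_in[t], alpha, beta)
        # H @ W is a dense matrix-matrix product on the state (the kind of op tensor cores
        # accelerate); np.outer(k, v) writes a rank-1 update into the expanded state.
        H = np.tanh(f * (H @ W) + np.outer(k[t], v[t]))
        outputs[t] = q[t] @ H  # read out a d_v-dimensional vector
    return outputs, H
```

Under these assumptions, growing $d_k$ enlarges the state without changing the per-token input and output widths. This is one way a state-size expansion can be expressed as matrix multiplications.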

Paper Structure (62 sections, 1 theorem, 13 equations, 8 figures, 8 tables, 2 algorithms)

Key Result

Theorem 1

The M$^2$RNN recurrence can represent all tasks representable by non-linear vector-valued RNNs and hence can represent regular languages.
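The $S_3$ state-tracking task of Figure 3 is a concrete regular language of this kind. Its running product of permutations is tracked by a finite automaton whose six states are the elements of $S_3$. The plain-Python sketch below builds that automaton as an explicit transition table and uses it to generate inputs and labels. The helper names, the token encoding, and the composition convention (apply the earlier permutation first) are assumptions made for illustration.

```python
import itertools
import random

# The six elements of S_3, each written as a tuple p where p[i] is the image of i.
ELEMENTS = list(itertools.permutations(range(3)))
INDEX = {p: i for i, p in enumerate(ELEMENTS)}  # token id for each permutation
IDENTITY = (0, 1, 2)

def compose(a, b):
    # (a o b)(i) = a(b(i)): apply b first, then a (convention chosen for this sketch).
    return tuple(a[b[i]] for i in range(3))

# DFA transition table: state = running product so far, symbol = next permutation.
TRANSITIONS = {(s, x): compose(x, s) for s in ELEMENTS for x in ELEMENTS}

def make_example(length, rng):
    """Return (input token ids, target token ids) for one S_3 state-tracking sequence."""
    state = IDENTITY
    inputs, targets = [], []
    for _ in range(length):
        x = rng.choice(ELEMENTS)
        state = TRANSITIONS[(state, x)]  # one DFA step = one group multiplication
        inputs.append(INDEX[x])
        targets.append(INDEX[state])
    return inputs, targets
```

For example, `make_example(128, random.Random(0))` produces a sequence at the training length used in Figure 3. Calling it with larger lengths produces the longer evaluation sequences used to probe length generalization. Because $S_3$ is non-abelian, the label at each position depends on the order of all earlier inputs, which is what makes the task a test of state tracking.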

Figures (8)

  • Figure 1: Forget gate ($f_t$) behavior as a function of the input ($x_t$) for different values of $\alpha_n$ and $\beta_n$, where the subscript $n$ indexes the head.
  • Figure 2: Visualization of the Matrix-to-Matrix RNN layer. This block replaces attention and is combined with MLP and RMSNorm modules as in the Transformer.
  • Figure 3: Accuracy on the permutation group $S_3$ state-tracking task for M$^2$RNN, Gated DeltaNet $[-1, 1]$, Gated DeltaProduct $[-1, 1]$ and GRU. The vertical line at $128$ marks the training context length for all models. For Gated DeltaProduct, we use a product of 2 Householder matrices, since this is sufficient to solve the $S_3$ task in theory grazzi2024unlocking. However, we observe accuracy degradation as the evaluation context length increases beyond the training context length.
  • Figure 4: TP topology-aware M$^2$RNN layer running with tensor parallelism (TP) on 2 GPUs. Note that RMSNorm$_1$ and RMSNorm$_2$ are separate RMSNorm modules with different weights on the two GPUs.
  • Figure 5: TP topology-independent M$^2$RNN layer. Note that RMSNormTP has its weights sharded along the model width dimension across both GPUs, which requires a synchronization in both the forward and backward passes.
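The synchronization mentioned for RMSNormTP follows from the definition of RMSNorm. The root-mean-square is taken over the full model width, so no single GPU holding one slice of the width can compute it alone. The NumPy sketch below simulates two GPUs in one process, with a plain Python sum standing in for the all-reduce. It contrasts a width-sharded RMSNorm that matches the unsharded result with independent per-shard norms that need no communication but normalize each slice by its own statistic. This is one plausible reading of the Figure 4 and Figure 5 captions, not the paper's implementation; the helper names and `EPS` are assumptions. Only the forward pass is shown. The backward pass needs a matching reduction because the gradient through the RMS statistic also sums over the full width.

```python
import numpy as np

EPS = 1e-6  # assumed epsilon for numerical stability

def rmsnorm(x, weight, eps=EPS):
    # Reference RMSNorm over the full model width (last axis).
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def rmsnorm_tp(x_shards, w_shards, eps=EPS):
    """Width-sharded RMSNorm: one activation/weight shard per simulated GPU."""
    width = sum(x.shape[-1] for x in x_shards)
    # Each GPU computes a partial sum of squares over its slice of the width...
    partial = [np.sum(x * x, axis=-1, keepdims=True) for x in x_shards]
    # ...and an all-reduce (simulated here by a Python sum) gives every GPU the global total.
    total = sum(partial)
    rms = np.sqrt(total / width + eps)
    return [x / rms * w for x, w in zip(x_shards, w_shards)]

def rmsnorm_per_shard(x_shards, w_shards, eps=EPS):
    # Independent norm per GPU: no communication, but each shard uses only its own statistic,
    # so the result changes with the number of shards.
    return [rmsnorm(x, w, eps) for x, w in zip(x_shards, w_shards)]
```

Concatenating the outputs of `rmsnorm_tp` reproduces `rmsnorm` on the unsharded input for any number of shards. The per-shard variant generally does not. This is the sense in which one layout is independent of the TP topology and the other depends on it.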

Theorems & Definitions (2)

  • Theorem 1
  • Proof of Theorem 1