Gated recurrent neural networks discover attention

Nicolas Zucchet, Seijin Kobayashi, Yassir Akram, Johannes von Oswald, Maxime Larcher, Angelika Steger, João Sacramento

TL;DR

This paper demonstrates that gated recurrent neural networks with linear diagonal recurrence and multiplicative gates can exactly implement linear self-attention, revealing a structural bridge between RNNs and Transformer-style attention. Through a formal constructive argument, it shows how to store and process past inputs to reproduce the attention operation with a finite, if large, number of neurons, and it analyzes parameter efficiency and invariances. Empirically, it shows that trained gated RNNs can learn attention-like solutions (teacher-student experiments) and, in in-context learning tasks, discover gradient-descent-like algorithms akin to those used by linear self-attention. The findings suggest attention-like computation can be encoded inside RNNs, offering insights for architecture design, potential compressions, and connections to neuroscience, while also recognizing practical limits due to parameter counts and nonlinearity effects.
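The claim that a recurrent network can store past inputs and reproduce attention rests on a simple regrouping. For causal linear self-attention without softmax, $y_t = \sum_{s \le t} v_s (k_s^\top q_t) = \big(\sum_{s \le t} v_s k_s^\top\big) q_t$, so the sum in parentheses can be kept as a running state. The NumPy sketch below checks this numerically. It assumes unnormalized causal linear attention with $v_t = W_V x_t$, $k_t = W_K x_t$, $q_t = W_Q x_t$, and its function names are illustrative rather than taken from the paper.

```python
import numpy as np


def linear_attention_parallel(X, W_V, W_K, W_Q):
    """Causal, unnormalized linear self-attention: y_t = sum_{s<=t} v_s (k_s . q_t)."""
    V, K, Q = X @ W_V.T, X @ W_K.T, X @ W_Q.T  # row t holds v_t, k_t, q_t
    scores = np.tril(Q @ K.T)                  # scores[t, s] = q_t . k_s, zeroed for s > t
    return scores @ V


def linear_attention_recurrent(X, W_V, W_K, W_Q):
    """Same map with a running state S_t = S_{t-1} + v_t k_t^T and y_t = S_t q_t."""
    S = np.zeros((W_V.shape[0], W_K.shape[0]))
    ys = []
    for x in X:
        v, k, q = W_V @ x, W_K @ x, W_Q @ x
        S = S + np.outer(v, k)  # perfect memory: the state only accumulates
        ys.append(S @ q)
    return np.stack(ys)


rng = np.random.default_rng(0)
d, T = 4, 10
X = rng.standard_normal((T, d))
W_V, W_K, W_Q = (rng.standard_normal((d, d)) for _ in range(3))
assert np.allclose(linear_attention_parallel(X, W_V, W_K, W_Q),
                   linear_attention_recurrent(X, W_V, W_K, W_Q))
```

In this sketch the state has $d_v \times d_k$ entries whatever the sequence length, in line with the "finite, if large" neuron count mentioned above.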

Abstract

Recent architectural developments have enabled recurrent neural networks (RNNs) to reach and even surpass the performance of Transformers on certain sequence modeling tasks. These modern RNNs feature a prominent design pattern: linear recurrent layers interconnected by feedforward paths with multiplicative gating. Here, we show how RNNs equipped with these two design elements can exactly implement (linear) self-attention, the main building block of Transformers. By reverse-engineering a set of trained RNNs, we find that gradient descent in practice discovers our construction. In particular, we examine RNNs trained to solve simple in-context learning tasks on which Transformers are known to excel and find that gradient descent instills in our RNNs the same attention-based in-context learning algorithm used by Transformers. Our findings highlight the importance of multiplicative interactions in neural networks and suggest that certain RNNs might be unexpectedly implementing attention under the hood.
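To make the gradient-descent connection concrete, here is a minimal sketch of a standard identity, assuming an in-context linear regression task with context pairs $(a_i, y_i)$, squared loss $L(W) = \frac{1}{2N}\sum_i \lVert W a_i - y_i \rVert^2$, initialization $W = 0$, and a single gradient step with learning rate `lr`. The resulting prediction for a query $a_q$ is $\frac{\mathrm{lr}}{N}\sum_i y_i (a_i^\top a_q)$, which is an unnormalized linear-attention readout with keys $a_i$, values $y_i$ and query $a_q$. The token layout, the function names and the single-step setting are assumptions for illustration; this section does not specify the exact algorithm the trained RNNs recover.

```python
import numpy as np


def one_step_gd_prediction(A, Y, a_query, lr):
    """Predict y for a_query after one gradient step on
    L(W) = 1/(2N) sum_i ||W a_i - y_i||^2, starting from W = 0."""
    N = A.shape[0]
    W0 = np.zeros((Y.shape[1], A.shape[1]))
    grad = (W0 @ A.T - Y.T) @ A / N  # dL/dW evaluated at W0
    W1 = W0 - lr * grad
    return W1 @ a_query


def linear_attention_prediction(A, Y, a_query, lr):
    """Linear attention with keys a_i, values y_i and query a_query, scaled by lr / N."""
    N = A.shape[0]
    scores = A @ a_query              # k_i . q for every context token
    return lr / N * (Y.T @ scores)    # sum_i y_i (a_i . a_query), rescaled


rng = np.random.default_rng(0)
A = rng.standard_normal((32, 5))      # context inputs a_i
Y = A @ rng.standard_normal((5, 1))   # targets y_i = W* a_i for a random W*
a_q = rng.standard_normal(5)
assert np.allclose(one_step_gd_prediction(A, Y, a_q, lr=0.1),
                   linear_attention_prediction(A, Y, a_q, lr=0.1))
```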

Paper Structure (45 sections, 40 equations, 7 figures, 4 tables)

Figures (7)

  • Figure 1: An example of a diagonal linear gated recurrent neural network that implements the same function as a linear self-attention layer with parameters $(W_V, W_K, W_Q)$ and input dimension $d$, as described in Section \ref{sec:construction}. Inputs are processed from top to bottom. We do not use biases, so we append 1 to the input vector $x_t$ to be able to send queries to the recurrent neurons. We use $\mathrm{repeat}(A, n)$ to denote that the matrix $A$ is repeated $n$ times along the row axis, and $W_{V,i}$ is the $i$-th row of the $W_V$ matrix. The bars within the matrices separate the different kinds of inputs/outputs. Digits in matrices denote appropriately sized column vectors. The readout matrix $D$ sums the elementwise products between key-values and queries computed after the output gating $g^\mathrm{out}$. Exact matrix values can be found in Appendix \ref{app:explicit-construction}.
  • Figure 2: In our teacher-student experiment of Section \ref{subsec:teacher-identification} ($d=4$), the structure of the RNN weights after learning matches that of our compact construction, cf. Section \ref{sec:construction}. (A) Summary of the post-processing we apply to the trained network weights. The number of recurrent neurons is denoted $n$, and the number of neurons after the output gating is denoted $m$. (B) Only recurrent neurons with perfect memory ($\lambda=1$, dark blue) or no memory at all ($\lambda=0$, light grey) influence the output, consistent with the theory. The block structure of the different weight matrices almost perfectly matches that of our construction, cf. Figure \ref{fig:construction}. (C) The last three output neurons of the output gating are functionally equivalent to a single neuron whose input weights match the structure of the rest of the output gating weights. This can be seen by representing each such neuron as an outer product (left part), which is later combined by the readout matrix $D$. The combined kernels are rank 1 and proportional to each other, so they can be expressed as the same outer product (right part). In all the matrices displayed here, zero entries are shown in light grey, blue denotes positive entries, and red negative ones.
  • Figure 3: Gated RNNs learn compressed representations when possible. (A, B) In the teacher-student experiment of Section \ref{sec:gatedRNN_identification}, the gated RNN identifies the teacher function under mild overparametrization. When the attention layer weights are low rank (B), the RNN learns a more compressed representation than when they are full rank (A). (C) In the linear regression task of Section \ref{sec:ICL}, the gated RNN behaves similarly to the optimal linear attention layer for that task, as the difference between their losses (delta loss) goes to 0. Moreover, the RNN discovers the same low-rank structure as this attention layer.
  • Figure 4: Comparison of the test loss obtained by different gated recurrent network architectures in (A) the teacher-student task of Section \ref{sec:gatedRNN_identification} and (B) the in-context linear regression task of Section \ref{sec:ICL}. The construction baseline corresponds to the gated RNN of Eq. \ref{eq:gated_rnn}, with diagonal or dense connectivity. We use the default implementations of LSTMs and GRUs, and slightly modify the LRU architecture to better reflect our construction. Non-linearity improves in-context learning performance but deteriorates the ability to mimic attention.
  • Figure 5: Construction for gated RNNs with side gating, as described in Section \ref{app:side_gating_construction}.
  • ...and 2 more figures
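The construction of Figure 1 can be sketched in NumPy. This sketch assumes a gated RNN of the form $h_t = \lambda \odot h_{t-1} + (W^\mathrm{in}_a \tilde x_t) \odot (W^\mathrm{in}_b \tilde x_t)$ and $y_t = D\,[(W^\mathrm{out}_a h_t) \odot (W^\mathrm{out}_b h_t)]$ with $\tilde x_t = (x_t, 1)$. The names $W^\mathrm{in}_a$, $W^\mathrm{out}_a$, etc., the neuron ordering, and the square $d \times d$ weights are hypothetical choices. This uncompressed version uses $d^2 + d$ recurrent neurons and need not match the exact matrices given in the paper's appendix. Key-value neurons have perfect memory ($\lambda = 1$) and query neurons have none ($\lambda = 0$), mirroring Figure 2(B).

```python
import numpy as np


def build_gated_rnn(W_V, W_K, W_Q):
    """Hypothetical weights for an uncompressed diagonal gated RNN that mimics
    causal linear self-attention (square d x d weight matrices assumed)."""
    d = W_V.shape[0]
    n_kv = d * d          # key-value neurons, one per product v_i * k_j
    n = n_kv + d          # plus d query neurons
    W_in_a = np.zeros((n, d + 1))
    W_in_b = np.zeros((n, d + 1))
    W_out_a = np.zeros((n_kv, n))
    W_out_b = np.zeros((n_kv, n))
    D = np.zeros((d, n_kv))
    for i in range(d):
        for j in range(d):
            r = i * d + j
            W_in_a[r, :d] = W_V[i]       # gate branch 1: v_i = W_{V,i} x
            W_in_b[r, :d] = W_K[j]       # gate branch 2: k_j, so neuron r receives v_i * k_j
            W_out_a[r, r] = 1.0          # output gate reads the memory S_ij ...
            W_out_b[r, n_kv + j] = 1.0   # ... and multiplies it by the query q_j
            D[i, r] = 1.0                # readout sums S_ij * q_j over j
    for j in range(d):
        W_in_a[n_kv + j, :d] = W_Q[j]    # q_j ...
        W_in_b[n_kv + j, d] = 1.0        # ... times the appended constant 1
    lam = np.concatenate([np.ones(n_kv), np.zeros(d)])  # perfect vs. no memory
    return W_in_a, W_in_b, lam, W_out_a, W_out_b, D


def run_gated_rnn(X, params):
    """h_t = lam * h_{t-1} + gated input; y_t = D (gated output)."""
    W_in_a, W_in_b, lam, W_out_a, W_out_b, D = params
    h = np.zeros(lam.shape[0])
    ys = []
    for x in X:
        x1 = np.append(x, 1.0)                          # no biases: append 1
        h = lam * h + (W_in_a @ x1) * (W_in_b @ x1)     # diagonal linear recurrence
        ys.append(D @ ((W_out_a @ h) * (W_out_b @ h)))  # output gating + readout
    return np.stack(ys)


def linear_attention(X, W_V, W_K, W_Q):
    """Reference: causal, unnormalized linear self-attention."""
    V, K, Q = X @ W_V.T, X @ W_K.T, X @ W_Q.T
    return np.tril(Q @ K.T) @ V


rng = np.random.default_rng(0)
d, T = 3, 6
W_V, W_K, W_Q = (rng.standard_normal((d, d)) for _ in range(3))
X = rng.standard_normal((T, d))
params = build_gated_rnn(W_V, W_K, W_Q)
assert np.allclose(run_gated_rnn(X, params), linear_attention(X, W_V, W_K, W_Q))
```

Because the query neurons have $\lambda = 0$, they only relay $q_t$ to the output gating, which is why the constant 1 is needed at the input; the $d^2$ neurons with $\lambda = 1$ hold the entries of $\sum_{s \le t} v_s k_s^\top$.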