Hybrid Associative Memories

Leon Lufkin, Tomás Figliolia, Beren Millidge, Kamesh Krishnamurthy

Abstract

Recurrent neural networks (RNNs) and self-attention are both widely used sequence-mixing layers that maintain an internal memory. However, they construct this memory using two orthogonal mechanisms: RNNs compress the entire past into a fixed-size state, whereas self-attention stores every past time step, growing its state (the KV cache) linearly with the sequence length. This results in complementary strengths and weaknesses: self-attention layers excel at retrieving information from the context but have large memory and computational costs, while RNNs are more efficient but degrade over longer contexts and underperform on precise recall tasks. Prior work combining these mechanisms has focused primarily on naively interleaving them to reduce computational cost, without regard to how the two mechanisms complement each other. We propose the Hybrid Associative Memory (HAM) layer, which combines self-attention and RNNs while leveraging their individual strengths: the RNN compresses the entire sequence, while attention supplements it *only* with information that is difficult for the RNN to predict, which is therefore the most valuable information to store explicitly. HAM layers enable data-dependent growth of the KV cache, which the user can precisely control with a single, continuous threshold. We find that this fine-grained control of the KV-cache growth rate yields a smooth trade-off between KV-cache usage and both loss and performance. Empirically, we show that our hybrid architecture offers strong performance, competitive with RNNs and Transformers, even at substantially lower KV-cache usage.
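The abstract describes the routing idea only at a high level. Below is a minimal NumPy sketch of one way such a layer could work, under assumptions that this section does not state: the RNN is a single-head gated delta rule (the GDN baseline named in Figures 2 and 5), the routing score is one minus the cosine similarity between the RNN's prediction $S_{t-1}k_t$ and the true value $v_t$ (in the spirit of the cosine-similarity score mentioned for Figure 3), and the layer output is simply the sum of the RNN readout and softmax attention over the cached tokens. The names `ham_forward` and `cosine_routing_score` are hypothetical, and the learned router from Figure 2 is not modeled.

```python
import numpy as np


def cosine_routing_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Hypothetical surprise score 1 - cos(pred, target), which lies in [0, 2].

    A zero prediction (e.g. an empty RNN state) gives cos = 0 and a score of 1.
    """
    cos = float(pred @ target) / (np.linalg.norm(pred) * np.linalg.norm(target) + eps)
    return 1.0 - cos


def ham_forward(q: np.ndarray, k: np.ndarray, v: np.ndarray,
                alpha: np.ndarray, beta: np.ndarray, tau: float):
    """Sketch of a single-head HAM-style layer (assumed design, not the paper's code).

    q, k: (T, d_k) with unit-norm keys; v: (T, d_v); alpha, beta: (T,) gates in [0, 1];
    tau: routing threshold. Returns outputs (T, d_v), a mask of cached tokens, and scores.
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_v, d_k))        # fixed-size RNN state
    cache_k, cache_v = [], []       # KV cache holding only "surprising" tokens
    out = np.zeros((T, d_v))
    stored = np.zeros(T, dtype=bool)
    scores = np.zeros(T)
    for t in range(T):
        # 1. How well does the current RNN state already predict v_t from k_t?
        scores[t] = cosine_routing_score(S @ k[t], v[t])
        # 2. Only tokens the RNN predicts poorly are written to the KV cache.
        if scores[t] > tau:
            cache_k.append(k[t])
            cache_v.append(v[t])
            stored[t] = True
        # 3. Gated delta-rule update: S <- alpha * S (I - beta k k^T) + beta v k^T.
        S = alpha[t] * (S - beta[t] * np.outer(S @ k[t], k[t])) + beta[t] * np.outer(v[t], k[t])
        rnn_out = S @ q[t]
        # 4. Softmax attention over the sparse cache (zero contribution if it is empty).
        attn_out = np.zeros(d_v)
        if cache_k:
            K, V = np.stack(cache_k), np.stack(cache_v)
            logits = K @ q[t] / np.sqrt(d_k)
            w = np.exp(logits - logits.max())
            attn_out = (w / w.sum()) @ V
        out[t] = rnn_out + attn_out
    return out, stored, scores


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, d = 64, 16
    q = rng.standard_normal((T, d))
    k = rng.standard_normal((T, d))
    k /= np.linalg.norm(k, axis=1, keepdims=True)  # the delta rule assumes unit-norm keys
    v = rng.standard_normal((T, d))
    alpha, beta = np.full(T, 0.95), np.full(T, 0.5)
    for tau in (0.5, 1.0, 1.5):
        _, stored, _ = ham_forward(q, k, v, alpha, beta, tau)
        print(f"tau={tau}: KV-cache usage = {stored.mean():.2f}")
```

In this sketch the routing decision never feeds back into the RNN state, so the scores are the same for every threshold and raising `tau` can only shrink the set of cached tokens; whether the actual HAM layer shares this property is not stated in this section.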

Paper Structure (26 sections, 11 equations, 6 figures, 2 tables)

Figures (6)

  • Figure 1: Hybrid Associative Memory, which maintains both an RNN state $S_t$ and a KV cache that stores surprising tokens. The two memories work together in a complementary fashion.
  • Figure 2: Language modeling loss (top) and long-context accuracy on RULER (bottom) as a function of KV-cache usage in HAM. Both the loss and the accuracy vary smoothly with the amount of KV cache HAM uses, a user-controlled parameter. Notably, HAM with the learned router at $50\%$ KV-cache usage performs strongly, significantly outperforming the GDN hybrid with full self-attention layers at identical KV-cache usage.
  • Figure 3: Trajectory of the prediction error for the toy NIAH single 1 sequence, obtained from the 10th layer of an 800M HAM model that uses a cosine-similarity routing score and was trained on 50B tokens from Long Data Collections.
  • Figure 4: Routing scores $e_t$ averaged over 1,000 sequences from Institutional Books for selected layers of a HAM model with 0.5 KV-cache usage. The average scores show a consistent downward trend along the sequence, indicating that HAM picks up long-range structure in the sequences and uses its context to make better predictions toward the end of a sequence. The large deviations around the mean trend reflect substantial sequence-to-sequence variability in the routing scores.
  • Figure 5: Routing scores and gate ($\alpha_t$) for example sequences and layers. The vertical lines indicate time points where $\alpha_t < 0.05$, i.e., where the state $S_t$ is effectively reset. In addition to the heterogeneity of the routing scores, the plot illustrates how frequently GDN resets the RNN state along the sequence.
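The figure captions refer to several per-token quantities: the fraction of tokens written to the KV cache (Figure 2), routing scores $e_t$ averaged across sequences (Figure 4), and reset points where the gate satisfies $\alpha_t < 0.05$ (Figure 5). The following minimal NumPy sketch shows how one might compute these from logged scores and gates. All helper names are hypothetical, and calibrating the threshold as a quantile of held-out routing scores is an illustrative assumption, not necessarily how the authors set it. The demo uses synthetic random numbers, not data from the paper.

```python
import numpy as np


def kv_usage(scores: np.ndarray, tau: float) -> float:
    """Fraction of tokens whose routing score exceeds tau (i.e. would be cached)."""
    return float(np.mean(scores > tau))


def threshold_for_usage(scores: np.ndarray, target: float) -> float:
    """Hypothetical calibration: choose tau so that roughly `target` of tokens are cached.

    Uses the (1 - target) quantile of routing scores collected on held-out data.
    """
    return float(np.quantile(scores, 1.0 - target))


def mean_score_trajectory(scores: np.ndarray):
    """scores: (num_sequences, T). Per-position mean and std across sequences."""
    return scores.mean(axis=0), scores.std(axis=0)


def reset_points(alpha: np.ndarray, cutoff: float = 0.05) -> np.ndarray:
    """Time steps where the gate alpha_t < cutoff, i.e. the RNN state is effectively reset."""
    return np.flatnonzero(alpha < cutoff)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic placeholder scores (NOT results from the paper): 8 sequences x 128 tokens.
    scores = rng.uniform(0.0, 2.0, size=(8, 128))
    tau = threshold_for_usage(scores.ravel(), target=0.5)
    print("tau:", tau, "usage:", kv_usage(scores, tau))
    mean, std = mean_score_trajectory(scores)
    print("mean score at first/last position:", mean[0], mean[-1])
    alpha = rng.uniform(0.0, 1.0, size=128)
    print("reset points:", reset_points(alpha))
```

Because each routing decision compares a score against one scalar, `kv_usage` is non-increasing in `tau`; that monotonicity is what lets a single continuous threshold be mapped to a target KV-cache usage.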