Memory Caching: RNNs with Growing Memory

Ali Behrouz, Zeman Li, Yuan Deng, Peilin Zhong, Meisam Razaviyayn, Vahab Mirrokni

TL;DR

Memory Caching (MC) is a simple yet effective technique that enhances recurrent models by caching checkpoints of their memory states (a.k.a. hidden states). Results on language modeling and long-context understanding tasks show that MC improves the performance of recurrent models, supporting its effectiveness.

Abstract

Transformers have been established as the de-facto backbones for most recent advances in sequence modeling, mainly due to their growing memory capacity that scales with the context length. While beneficial for retrieval tasks, this growing memory causes quadratic complexity, which has motivated recent studies to explore viable subquadratic recurrent alternatives. Despite showing promising preliminary results in diverse domains, such recurrent architectures underperform Transformers in recall-intensive tasks, a gap often attributed to their fixed-size memory. In this paper, we introduce Memory Caching (MC), a simple yet effective technique that enhances recurrent models by caching checkpoints of their memory states (a.k.a. hidden states). Memory Caching allows the effective memory capacity of RNNs to grow with sequence length, offering a flexible trade-off that interpolates between the fixed memory (i.e., $O(L)$ complexity for a sequence of length $L$) of RNNs and the growing memory (i.e., $O(L^2)$ complexity) of Transformers. We propose four variants of MC, including gated aggregation and sparse selective mechanisms, and discuss their implications for both linear and deep memory modules. Our experimental results on language modeling and long-context understanding tasks show that MC enhances the performance of recurrent models, supporting its effectiveness. The results of in-context recall tasks indicate that while Transformers achieve the best accuracy, our MC variants show competitive performance, close the gap with Transformers, and perform better than state-of-the-art recurrent models.
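To make the idea of "caching checkpoints of memory states" concrete, here is a minimal sketch of a gated-aggregation flavor of MC on a linear memory. It is not the authors' implementation; it rests on several assumptions: the memory is a plain outer-product (linear-attention) memory $M_t = M_{t-1} + v_t k_t^\top$ with no decay or normalization, the sequence is split into constant-length segments, the online memory restarts at each segment, each finished segment is summarized by its mean key, and the gate over cached memories is a softmax of $q_t$ against those summaries. The function name `memory_caching_linear` is hypothetical.

```python
import numpy as np

def memory_caching_linear(q, k, v, seg_len):
    """Hypothetical sketch: memory caching on a linear (outer-product) memory.

    q, k, v: arrays of shape (L, d); seg_len: constant segment length (assumption).
    Returns outputs of shape (L, d).
    """
    L, d = q.shape
    cached = []  # (memory checkpoint, segment summary) for every finished segment
    out = np.zeros((L, d))
    for start in range(0, L, seg_len):
        end = min(start + seg_len, L)
        M = np.zeros((d, d))  # online memory; assumed to restart at each segment
        for t in range(start, end):
            M = M + np.outer(v[t], k[t])  # write: M_t = M_{t-1} + v_t k_t^T
            y = M @ q[t]                  # read from the online memory
            if cached:
                # Assumed gate: softmax over q_t . (segment summary) per cached memory
                scores = np.array([q[t] @ s for _, s in cached])
                gates = np.exp(scores - scores.max())
                gates /= gates.sum()
                for g, (M_c, _) in zip(gates, cached):
                    y = y + g * (M_c @ q[t])  # gated read from a cached memory
            out[t] = y
        # Checkpoint the memory at the segment boundary; summarize it by its mean key
        cached.append((M, k[start:end].mean(axis=0)))
    return out
```

With `seg_len = L` there is a single segment and no cache, so the sketch reduces to the plain linear recurrence. Under these assumptions, a token reads from at most about `L / seg_len` memories, so total cost grows roughly like $O(L^2 / C)$ for segment length $C$: $C = L$ recovers the RNN's $O(L)$ regime and $C = 1$ approaches the $O(L^2)$ regime, which is one way to read the interpolation described above.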

Paper Structure

This paper contains 22 sections, 24 equations, 5 figures, 6 tables.

Figures (5)

  • Figure 1: The Overall Memory Caching Method. Each token attends to its online memory as well as a set of cached memories from the past.
  • Figure 2: Sparse Selective Caching (SSC) of Memories. A router measures the contextual similarity of each token to its past segments and chooses a subset of the cached past memories for better efficiency.
  • Figure 3: An illustrative example of memory caching with constant- and logarithmic-size segments. Logarithmic segmentation, while computationally appealing, results in long subsequences that might cause memory overflow and/or short subsequences that prevent the memory from properly optimizing itself in the inner loop.
  • Figure 4: Training throughput comparison of memory caching variants and baselines.
  • Figure 5: Average accuracy on MQAR over 5 seeds.
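Figures 2 and 3 describe two design choices: how the sequence is cut into segments, and how Sparse Selective Caching (SSC) picks which cached memories a token reads. The sketch below illustrates both under stated assumptions rather than reproducing the paper's method: "logarithmic" segmentation is taken to mean segment lengths that double (1, 2, 4, ...), the router score is a dot product between the token representation and a per-segment summary vector, and the selected memories are combined with softmax weights renormalized over the top-k. All function names (`constant_segments`, `log_segments`, `sparse_select`, `ssc_read`) are hypothetical.

```python
import numpy as np

def constant_segments(L, c):
    """Split positions [0, L) into segments of fixed length c (last may be shorter)."""
    return [(s, min(s + c, L)) for s in range(0, L, c)]

def log_segments(L):
    """Assumed logarithmic scheme: lengths 1, 2, 4, ... giving O(log L) segments."""
    bounds, start, size = [], 0, 1
    while start < L:
        bounds.append((start, min(start + size, L)))
        start += size
        size *= 2
    return bounds

def sparse_select(x_t, seg_summaries, top_k):
    """Hypothetical SSC router: score each cached segment by dot product with the
    current token, keep the top_k, and softmax-renormalize their scores."""
    scores = seg_summaries @ x_t                  # one score per cached segment
    k = min(top_k, len(scores))
    idx = np.argsort(scores)[::-1][:k]            # highest-scoring segments first
    w = np.exp(scores[idx] - scores[idx].max())
    return idx, w / w.sum()

def ssc_read(q_t, online_M, cached_M, idx, w):
    """Read the online memory plus a weighted sum of only the selected caches."""
    y = online_M @ q_t
    for i, w_i in zip(idx, w):
        y = y + w_i * (cached_M[i] @ q_t)
    return y
```

Under these assumptions, each token reads at most `top_k` cached memories instead of all of them, which matches the efficiency motivation in the Figure 2 caption. The doubling scheme yields only O(log L) segments, but it mixes very short early segments with very long late ones, consistent with the trade-off described for Figure 3.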