MELODI: Exploring Memory Compression for Long Contexts

Yinpeng Chen, DeLesley Hutchins, Aren Jansen, Andrey Zhmoginov, David Racz, Jesper Andersen

TL;DR

MELODI, a novel memory architecture designed to efficiently process long documents using short context windows, demonstrates superior performance on various long-context datasets while reducing the memory footprint by a factor of 8 relative to the Memorizing Transformer.

Abstract

We present MELODI, a novel memory architecture designed to efficiently process long documents using short context windows. The key principle behind MELODI is to represent short-term and long-term memory as a hierarchical compression scheme across both network layers and context windows. Specifically, the short-term memory is achieved through recurrent compression of context windows across multiple layers, ensuring smooth transitions between windows. In contrast, the long-term memory performs further compression within a single middle layer and aggregates information across context windows, effectively consolidating crucial information from the entire history. Compared to a strong baseline, the Memorizing Transformer employing dense attention over a large long-term memory (64K key-value pairs), our method demonstrates superior performance on various long-context datasets while remarkably reducing the memory footprint by a factor of 8.
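The hierarchical scheme in the abstract can be illustrated as per-window memory bookkeeping. This is a minimal sketch, not the authors' implementation: tokens are just labels, the stack has a hypothetical `n_layers` layers with the single long-term layer at index `m`, each layer keeps `s` short-term tokens that are overwritten every window, and the long-term layer appends `l` compressed key-value pairs per window. Evicting the oldest window once `max_windows` windows are stored is an assumed policy. The names `layer_plan`, `MemoryState`, and `process_window` are hypothetical.

```python
from collections import deque


def layer_plan(n_layers, m):
    """Layer types: m short-term layers, one long-term layer, then n_layers - m - 1 short-term layers."""
    if not 0 <= m < n_layers:
        raise ValueError("the long-term layer index must lie inside the stack")
    return ["short"] * m + ["long"] + ["short"] * (n_layers - m - 1)


class MemoryState:
    """Toy bookkeeping of MELODI-style memories; entries are labels, not learned vectors."""

    def __init__(self, s, l, max_windows):
        self.s = s                                  # short-term tokens per layer per window
        self.l = l                                  # long-term KV pairs added per window
        self.short_term = {}                        # layer index -> current short-term tokens z_k^l
        self.long_term = deque(maxlen=max_windows)  # one block m_k per stored window

    def process_window(self, k, plan):
        for layer, kind in enumerate(plan):
            # Short-term memory is recurrent: z_k^l replaces z_{k-1}^l, so its size stays fixed.
            self.short_term[layer] = [f"z[{k}][{layer}][{i}]" for i in range(self.s)]
            if kind == "long":
                # Long-term memory accumulates: m_{1:k} is m_{1:k-1} plus a new block m_k.
                # deque(maxlen=...) silently drops the oldest block (hypothetical eviction policy).
                self.long_term.append([f"m[{k}][{i}]" for i in range(self.l)])

    def long_term_size(self):
        return sum(len(block) for block in self.long_term)


plan = layer_plan(n_layers=12, m=6)
state = MemoryState(s=8, l=96, max_windows=128)
for k in range(200):
    state.process_window(k, plan)
print(plan.count("short"), plan.count("long"))            # 11 1
print(len(state.short_term[0]), state.long_term_size())   # 8 12288
```

With 96 long-term tokens per window over 128 stored windows, the long-term memory holds 12,288 KV pairs (the $L_{96}$ setting described for Figure 4), while each layer's short-term memory stays at 8 tokens no matter how long the document is.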

Paper Structure

This paper contains 16 sections, 2 equations, 8 figures, and 9 tables.

Figures (8)

  • Figure 1: Overview of Melodi. Melodi employs a hierarchical memory representation, incorporating both short-term and long-term compression mechanisms, integrated with a transformer-based language model. It utilizes a stack of short-term layers to recurrently compress each context window $x_k^0$ into short-term memory tokens $\{z^l_k\}$, and inserts a long-term layer to store compressed key-value pairs within a long-term memory $m_{1:k}$. Both short-term and long-term layers leverage modified transformer blocks. In this illustration, we assume a total of $N$ layers, with $M$ short-term layers preceding 1 long-term layer and $N-M-1$ short-term layers following it.
  • Figure 2: Short-term layer. The figure illustrates the processing of the $k^{th}$ context window at the $l^{th}$ short-term layer. It takes the memory from the previous window $z_{k-1}^l$ and the current context/summary ($x_{k}^{l-1}$, $u_{k}^{l-1}$) from the previous layer as input. The short-term layer adds two linear token mixers (MLP-Mixer; Tolstikhin et al., 2021) on top of a standard transformer layer (attention and FFN) to produce, separately, the summary for the next layer $u_{k}^{l}$ and the memory for the next window $z_{k}^l$. Best viewed in color.
  • Figure 3: Long-term layer. The long-term layer adds three components on top of the short-term layer (see Figure 2). First, it introduces a long-term memory $m_{1:k-1}$ by caching compressed key-value pairs and lets the current context/summary ($x_{k}^{l-1}$, $u_{k}^{l-1}$) cross-attend to them. Second, self-attention and cross-attention are integrated via gating. Finally, the linear token mixer outputs additional compressed tokens and appends their key-value pairs $m_k$ to the long-term memory (yielding $m_{1:k}$) for the next window. Best viewed in color.
  • Figure 4: Ablation of memory size on PG-19. Token perplexity is reported for various combinations of short-term and long-term memory sizes. Each curve represents a fixed long-term memory size, with points along the curve indicating different short-term memory sizes. For example, the blue curve ($L_{96}$) uses 96 long-term tokens per window over 128 windows, totaling 12,288 tokens. Each point on this curve represents a different short-term memory size (e.g., $S_8$ denotes 8 short-term tokens per context window). Memory size is measured in floating-point numbers (floats). For instance, $L_{96}$ stores 12,288 long-term key-value (KV) pairs, each with 1024 dimensions, for a total of 12,288$\times$1024$\times$2 $\approx$ 25.2 million floats. The table on the right gives the perplexity for each point in the left plot, using matching colors. These results show that long-term and short-term memories play complementary roles, and increasing the capacity of either improves performance. Notably, Melodi outperforms baselines such as Transformer-XL, Block Recurrent Transformer, and Memorizing Transformer while using less memory. Best viewed in color.
  • Figure 5: Long-term memory coverage. The coverage metric indicates the number of preceding context windows spanned by the long-term memory. All data points use $L$=64 long-term and $S$=128 short-term tokens per window but differ in long-term memory capacity. For instance, '32' denotes covering 32 context windows, totaling 32$\times$64=2048 long-term tokens. Perplexity improves only marginally as coverage grows from 2 to 4 windows; the improvement then accelerates up to 32 windows, after which it levels off.
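Figure 2's caption describes two linear token mixers added after the transformer block of a short-term layer. A linear token mixer multiplies along the token axis rather than the feature axis (like the token-mixing step of MLP-Mixer, but with no nonlinearity). The sketch below is a hedged illustration: it assumes plain nested lists for a sequence of token vectors, assumes both mixers read the full output sequence of the transformer block, and uses the hypothetical names `token_mix` and `split_outputs`. In the model the mixer weights are learned; the hand-picked weights here only make the arithmetic easy to follow.

```python
def token_mix(weights, tokens):
    """Linear token mixing: out[i][c] = sum_j weights[i][j] * tokens[j][c] for every feature c."""
    if not tokens:
        raise ValueError("need at least one input token")
    n_in, dim = len(tokens), len(tokens[0])
    if any(len(row) != n_in for row in weights):
        raise ValueError("each weight row must have one entry per input token")
    return [
        [sum(row[j] * tokens[j][c] for j in range(n_in)) for c in range(dim)]
        for row in weights
    ]


def split_outputs(block_out, w_summary, w_memory):
    """Hypothetical split of a short-term layer's output into u_k^l (next layer) and z_k^l (next window)."""
    u_next = token_mix(w_summary, block_out)  # one output token per row of w_summary
    z_next = token_mix(w_memory, block_out)   # one output token per row of w_memory
    return u_next, z_next


block_out = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # 3 tokens with 2 features each
u_next, z_next = split_outputs(
    block_out,
    w_summary=[[0.5, 0.5, 0.0]],
    w_memory=[[0.0, 0.0, 1.0], [1.0, 1.0, 0.0]],
)
print(u_next)  # [[0.5, 0.5]]
print(z_next)  # [[2.0, 2.0], [1.0, 1.0]]
```

Because the number of rows in each weight matrix fixes the number of output tokens, the same mechanism can compress a window of many tokens into a small, fixed number of memory tokens.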
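Figure 3's caption says the long-term layer integrates self-attention and cross-attention to the cached long-term memory via gating, but the gate's exact form is not given here. The sketch below assumes a learned scalar gate $g = \sigma(b)$, similar to the per-head gate used in the Memorizing Transformer, which weights the memory (cross-attention) output by $g$ and the local (self-attention) output by $1-g$. It is single-head and single-query, written in pure Python, and the names `attend` and `gated_attention` are hypothetical.

```python
import math


def softmax(scores):
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = 1.0 / math.sqrt(len(query))
    weights = softmax([scale * sum(q * k for q, k in zip(query, key)) for key in keys])
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]


def gated_attention(query, local_kv, memory_kv, gate_logit):
    """Blend self-attention (current window) with cross-attention (long-term memory m_{1:k-1})."""
    g = 1.0 / (1.0 + math.exp(-gate_logit))  # assumed scalar gate in (0, 1)
    self_out = attend(query, *local_kv)
    cross_out = attend(query, *memory_kv)
    return [g * c + (1.0 - g) * s for s, c in zip(self_out, cross_out)]


local_kv = ([[1.0, 0.0]], [[2.0, 0.0]])   # (keys, values) from the current window
memory_kv = ([[0.0, 1.0]], [[0.0, 4.0]])  # (keys, values) cached in long-term memory
print(gated_attention([1.0, 1.0], local_kv, memory_kv, gate_logit=0.0))  # [1.0, 2.0]
```

A gate near 0 makes the layer behave like an ordinary self-attention layer, while a gate near 1 routes almost all of the output through the long-term memory, so training can choose how much history each layer relies on.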
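The memory-size arithmetic in the captions of Figures 4 and 5 can be reproduced with two small helpers. Following Figure 4's caption, each long-term token is stored as one key vector and one value vector of 1024 dimensions; the function names are illustrative only.

```python
def long_term_floats(tokens_per_window, windows, kv_dim=1024):
    """Floats held in long-term memory: one key and one value vector per stored token."""
    return tokens_per_window * windows * kv_dim * 2


def coverage_tokens(windows_covered, tokens_per_window):
    """Long-term tokens needed to span a given number of preceding context windows."""
    return windows_covered * tokens_per_window


print(f"{long_term_floats(96, 128) / 1e6:.1f}M floats")  # 25.2M floats (L_96, Figure 4)
print(coverage_tokens(32, 64))                          # 2048 (coverage 32, Figure 5)
```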