
LoMA: Lossless Compressed Memory Attention

Yumeng Wang, Zhenyang Xiao

TL;DR

LoMA introduces a lossless memory compression scheme for transformer KV caches that extends usable context length at reduced compute cost. It uses a three-zone approach (reading zone, memory zone, repetition zone) and a specialized attention mask to train the model to compress memory losslessly, compressing each block of $tc$ tokens into $t$ memory tokens (compression ratio $c$). The method is validated by fine-tuning Llama-2-7B-Chat, showing near-zero repetition loss for $c\le8$, strong generalization from C4 to GSM8K, and notable reductions in memory usage and latency. Importantly, LoMA does not modify the model architecture or rely on auxiliary models, and it can be trained with publicly available data, making long-context processing more practical for existing LMs.

Abstract

Large Language Models (LLMs) face limitations when handling long contexts because of their high demand for GPU memory and computational resources. While sparsifying the Key-Value (KV) cache of a transformer model is a typical strategy for reducing resource usage, it unavoidably loses information. We introduce Lossless Compressed Memory Attention (LoMA), a novel approach that enables lossless compression of the KV cache, thereby reducing the memory and computational demands of autoregressive generation. LoMA combines a specialized training or fine-tuning procedure with an autoregressive generation algorithm optimized for the compressed context. Our method compresses the KV cache after every $tc$ generated tokens with a compression ratio of $c$ and a target compressed length $t$, and this compression occurs within a single inference pass without dependence on auxiliary models. We designed an efficient training scheme involving specific inputs, attention masks, and position identifiers to instill this compression capability. Experiments demonstrate that LoMA significantly reduces computational consumption and memory usage by achieving lossless KV cache compression.
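To make the compression schedule concrete, the following minimal sketch counts KV-cache entries after $n$ context tokens. It assumes that each completed block of $tc$ tokens is replaced by its $t$ memory entries, that earlier memory entries are retained, and that the current partial block stays uncompressed; the function name `loma_cache_len` is hypothetical and not from the paper.

```python
def loma_cache_len(n_tokens: int, t: int, c: int) -> int:
    """KV entries held after n_tokens context tokens (assumed schedule).

    Each completed block of t*c tokens is assumed to be replaced by t
    memory entries; tokens in the current, unfinished block stay raw.
    """
    block = t * c
    full_blocks, remainder = divmod(n_tokens, block)
    return full_blocks * t + remainder


if __name__ == "__main__":
    t, c = 16, 4
    for n in (64, 1000, 4096):
        # A standard KV cache would hold n entries.
        print(n, loma_cache_len(n, t, c))
    # 64 16
    # 1000 280
    # 4096 1024
```

Under these assumptions the cache approaches a factor-$c$ reduction relative to a standard cache of $n$ entries; it is exactly $c$ whenever $n$ is a multiple of $tc$.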

Paper Structure (18 sections, 18 equations, 7 figures, 2 tables)

Figures (7)

  • Figure 1: Comparison of the standard transformer model with the LoMA model in autoregressive generation: (a) In the standard transformer model's autoregressive generation, the input token and the previous context's KV cache are fed together into the attention module to compute and predict the next token. (b) In the LoMA model's autoregressive generation, the previous context's KV cache is first compressed, and the input token is processed with the compressed KV cache by the attention module.
  • Figure 2: Single-inference latency versus KV cache length for various input token sequence lengths. Latency grows linearly with KV cache length, whereas increasing the input sequence length does not substantially affect computation time. Notably, with a 16-token input, increasing the KV cache length from 0 to 240 incurs no additional inference time, possibly because of the compute characteristics of the hardware.
  • Figure 3: The top row shows the original training samples, and the bottom row shows the processed samples used to train or fine-tune the LoMA model. After every $tc$ tokens of an original sample, we insert $t$ '<m>' tokens and $tc$ '<r>' tokens.
  • Figure 4: Correspondence between inputs and labels. In the reading zone, inputs and targets follow the standard autoregressive relationship. The memory zone has no labels, while the labels in the repetition zone are the content of the reading zone. We show in the Gradient section that backpropagating gradients through the repetition zone provides a supervisory signal to the memory zone. This allows the '<m>' tokens to learn to compress the content of the reading zone into their own KV.
  • Figure 5: The attention mask for an input sequence of 12 tokens, including the initial token '<s>'. With $t=2$ and $c=2$, the reading and repetition zones each span $tc=4$ tokens and the memory zone spans $t=2$ tokens, so the sequence is split into three training chunks. Each chunk's 4 reading tokens are followed by 2 '<m>' tokens and 4 '<r>' tokens, giving a chunk length of 10 tokens (4+2+4) and an attention mask of size $30\times30$. Grey squares denote $0$ (attention blocked) and blue squares denote $1$ (attention allowed).
  • ...and 2 more figures
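The sample layout described for Figures 3 and 4 can be sketched as follows. This is an illustrative reconstruction, not the authors' code: `build_loma_sample`, the token ids passed for '<m>' and '<r>', and the use of -100 as the "no label" marker (PyTorch's default `ignore_index` for cross-entropy) are assumptions. It further assumes that repetition position $j$ targets reading token $j$, that the last reading position is unlabeled, and that a trailing partial chunk is dropped.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # assumed "no label" marker (PyTorch's default ignore_index)


def build_loma_sample(tokens: List[int], t: int, c: int,
                      m_id: int, r_id: int) -> Tuple[List[int], List[int]]:
    """Return (inputs, labels) with reading/memory/repetition zones per chunk.

    labels[i] is the target predicted at position i (assumed alignment).
    """
    r = t * c                                # reading/repetition zone length
    usable = len(tokens) - len(tokens) % r   # drop trailing partial chunk (simplification)
    inputs: List[int] = []
    labels: List[int] = []
    for start in range(0, usable, r):
        read = tokens[start:start + r]
        # Reading zone: standard next-token targets; last position unlabeled.
        inputs += read
        labels += read[1:] + [IGNORE_INDEX]
        # Memory zone: t '<m>' tokens with no labels.
        inputs += [m_id] * t
        labels += [IGNORE_INDEX] * t
        # Repetition zone: tc '<r>' tokens whose labels replay the reading zone.
        inputs += [r_id] * r
        labels += read
    return inputs, labels


inputs, labels = build_loma_sample(list(range(1, 9)), t=2, c=2, m_id=100, r_id=101)
print(inputs[:10])  # [1, 2, 3, 4, 100, 100, 101, 101, 101, 101]
print(labels[:10])  # [2, 3, 4, -100, -100, -100, 1, 2, 3, 4]
```

Because the memory zone carries no labels, any gradient reaching the '<m>' positions must come from the repetition-zone loss, which is the supervisory path Figure 4 describes.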
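A LoMA-style attention mask for the chunk layout in Figure 5 can be sketched with NumPy. The visibility rules are assumed, inferred from the zone descriptions rather than taken from the paper: reading and memory tokens see the memory zones of earlier chunks plus their own chunk causally, while repetition tokens see earlier memory zones and their own memory zone but not their own reading zone, so they can only reconstruct it from memory. The paper's exact mask, including any special handling of '<s>', may differ, and `loma_attention_mask` is a hypothetical name.

```python
import numpy as np


def loma_attention_mask(num_chunks: int, t: int, c: int) -> np.ndarray:
    """0/1 mask (1 = may attend) under assumed LoMA-style visibility rules."""
    r = t * c                  # reading and repetition zone length
    chunk = r + t + r          # tokens per training chunk
    n = num_chunks * chunk
    mask = np.zeros((n, n), dtype=np.int8)
    for k in range(num_chunks):
        base = k * chunk
        mem_start, rep_start = base + r, base + r + t
        # Memory-zone columns of all earlier chunks: the compressed history.
        past_mem = [j * chunk + r + m for j in range(k) for m in range(t)]
        for i in range(base, base + chunk):
            if past_mem:
                mask[i, past_mem] = 1
            if i < rep_start:
                # Reading and memory tokens: causal over this chunk's
                # reading zone followed by its memory zone.
                mask[i, base:i + 1] = 1
            else:
                # Repetition tokens: this chunk's memory zone, then causal
                # over the repetition zone; the reading zone stays hidden.
                mask[i, mem_start:rep_start] = 1
                mask[i, rep_start:i + 1] = 1
    return mask


mask = loma_attention_mask(num_chunks=3, t=2, c=2)
print(mask.shape)  # (30, 30)
```

With three chunks and $t=c=2$ the result has the $30\times30$ shape given in Figure 5, and every allowed entry lies on or below the diagonal, so the mask remains causal.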