Inference-Friendly Models With MixAttention

Shashank Rajput, Ying Sheng, Sean Owen, Vitaliy Chiley

TL;DR

This work explores MixAttention, a model architecture modification closely related to one described in a blog post by Character.AI. MixAttention combines sliding window attention, in which only a small subset of recent tokens is stored in the KV cache, with KV cache sharing across layers.
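
Sliding window attention keeps the KV cache bounded: once the window is full, the oldest token's keys and values are evicted. The sketch below illustrates this for a single attention head using NumPy. It is not the authors' implementation. `SlidingWindowKVCache` is a hypothetical name, and the `deque` is chosen for readability; a production kernel would more likely use a preallocated ring buffer.

```python
from collections import deque

import numpy as np


class SlidingWindowKVCache:
    """Hypothetical helper: stores keys/values for only the most recent `window` tokens."""

    def __init__(self, window: int):
        self.window = window
        # A deque with maxlen drops its oldest entry automatically once it is full.
        self.keys = deque(maxlen=window)
        self.values = deque(maxlen=window)

    def append(self, k: np.ndarray, v: np.ndarray) -> None:
        self.keys.append(k)
        self.values.append(v)

    def attend(self, q: np.ndarray) -> np.ndarray:
        """Single-head softmax attention of one query over the cached recent tokens."""
        K = np.stack(list(self.keys))    # (n_cached, d)
        V = np.stack(list(self.values))  # (n_cached, d)
        scores = K @ q / np.sqrt(q.shape[-1])
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        return weights @ V


rng = np.random.default_rng(0)
cache = SlidingWindowKVCache(window=4)
d = 8
for _ in range(10):  # decode 10 tokens one at a time
    k, v, q = rng.standard_normal((3, d))
    cache.append(k, v)
    out = cache.attend(q)
print(len(cache.keys))  # 4: memory stays bounded no matter how many tokens are decoded
```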

Abstract

The size of the key-value (KV) cache plays a critical role in determining both the maximum context length and the number of concurrent requests supported during inference in modern language models. The KV cache size grows proportionally with the number of layers, the number of key-value heads, and the number of tokens processed, leading to increased memory consumption and slower inference for long inputs. In this work, we explore MixAttention, a model architecture modification closely related to one described in a blog post by Character.AI. MixAttention combines sliding window attention, in which only a small subset of recent tokens is stored in the KV cache, with KV cache sharing across layers. Our experiments demonstrate that MixAttention significantly reduces memory usage and improves inference speed without sacrificing model performance on both short- and long-context tasks. We also explore various configurations of this architecture, identifying those that maintain quality across evaluation metrics while optimizing resource efficiency.
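
A back-of-the-envelope calculation shows why combining the two ideas saves memory. It assumes the usual per-layer cost of 2 (keys and values) × KV heads × head dimension × cached tokens × bytes per element. A layer that reuses an earlier layer's cache adds nothing, and a sliding window layer caches at most `window` tokens. The 24-layer layout, window size, and head counts below are hypothetical illustration values, not the configurations evaluated in the paper.

```python
from dataclasses import dataclass


@dataclass
class LayerSpec:
    sliding: bool     # True: sliding window attention; False: standard attention
    owns_cache: bool  # False: the layer reuses an earlier layer's KV cache


def kv_cache_bytes(layers, seq_len, window, kv_heads, head_dim, bytes_per_elem=2):
    """KV cache size for one sequence; the factor 2 counts both keys and values."""
    total = 0
    for layer in layers:
        if not layer.owns_cache:
            continue  # shared cache: no additional memory
        tokens = min(seq_len, window) if layer.sliding else seq_len
        total += 2 * kv_heads * head_dim * tokens * bytes_per_elem
    return total


# Hypothetical 24-layer model (not the paper's configuration).
n_layers = 24
standard = [LayerSpec(sliding=False, owns_cache=True) for _ in range(n_layers)]

# Hypothetical MixAttention-style layout: every 6th layer is standard attention and
# all of them share layer 0's cache; the sliding window layers in each block of 6
# share the cache of that block's first sliding window layer.
mix = [
    LayerSpec(sliding=(i % 6 != 0), owns_cache=(i == 0 or i % 6 == 1))
    for i in range(n_layers)
]

cfg = dict(seq_len=32_768, window=1_024, kv_heads=8, head_dim=128)  # 2 bytes/elem (16-bit)
standard_bytes = kv_cache_bytes(standard, **cfg)
mix_bytes = kv_cache_bytes(mix, **cfg)
print(f"standard: {standard_bytes / 2**20:.0f} MiB")  # 3072 MiB
print(f"mix:      {mix_bytes / 2**20:.0f} MiB")       # 144 MiB
```

In this toy layout, only one full-length cache and four 1,024-token caches are stored, so at 32K context the KV cache shrinks by roughly 21×. The actual savings depend on how many standard attention caches a configuration keeps.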

Paper Structure

This paper contains 25 sections and 12 figures.

Figures (12)

  • Figure 1: (Left) Variants of the MixAttention architecture: green bars represent sliding window attention, and the curved lines connecting bars represent KV cache sharing. (Right, top row) MixAttention models are faster and use less memory during inference at 32K context length. (Right, bottom row) MixAttention models maintain quality: they match the standard attention model on most evals. All models are Mixture of Experts models with 2B active and 5B total parameters.
  • Figure 2: MixAttention: (Left) A standard transformer model where all layers are standard attention layers. (Middle) Inference-friendly models with MixAttention. Green bars represent sliding window attention and the lines connecting bars represent KV cache sharing. (Right) A model where all layers are sliding window attention.
  • Figure 3: KV cache position and count: To measure how the position and number of standard attention KV caches affect MixAttention's long-context abilities, we train and evaluate the 4 models shown above.
  • Figure 4: Effect of Standard Attention Layers: (Top) Loss curves of the models when fine-tuning on a long-context QA dataset. (Bottom) RULER evals for the models. MA and MA-EndSlide perform poorly on long-context tasks, whereas MA-Offset and MA-Pairs perform well. This indicates that having a standard attention KV cache computed in later layers is important for long-context abilities. We also found that the loss on the long-context QA dataset correlates well with the model's long-context abilities.
  • Figure 5: Increasing KV cache sharing in sliding window layers: To measure the effect of KV cache sharing in the sliding window layers, we compared the architectures shown in the figure above.
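
The layer layouts in Figures 2, 3, and 5 can be expressed as a per-layer plan. Each layer is either standard or sliding window attention, and it either computes its own K/V or reads the cache of an earlier layer. The following is a minimal single-head NumPy sketch that assumes no MLPs or normalization. The `plan` format and `mix_attention_forward` are hypothetical names, and the 6-layer layout is illustrative only. It is not any of the paper's MA variants and not the authors' code.

```python
import numpy as np


def causal_attention(q, k, v, window=None):
    """Single-head causal softmax attention; with `window`, each query sees only the last `window` positions."""
    T = q.shape[0]
    i = np.arange(T)[:, None]
    j = np.arange(T)[None, :]
    mask = j <= i                  # causal: no attending to future tokens
    if window is not None:
        mask &= j > i - window     # sliding window: drop tokens older than `window`
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v


def mix_attention_forward(x, plan, weights):
    """plan[i] = (window, src): window=None means standard attention; src <= i is the layer
    whose K/V layer i uses (src == i means layer i computes and caches its own K/V)."""
    kv = {}
    for i, (window, src) in enumerate(plan):
        wq, wk, wv = weights[i]
        if src == i:
            kv[i] = (x @ wk, x @ wv)  # this layer owns a KV cache
        k, v = kv[src]                # sharing layers only compute queries
        x = x + causal_attention(x @ wq, k, v, window)
    return x, len(kv)


# Hypothetical 6-layer plan: layers 0 and 3 are standard attention sharing one cache;
# sliding window layers are paired, so 3 KV caches are stored instead of 6.
W = 4
plan = [(None, 0), (W, 1), (W, 1), (None, 0), (W, 4), (W, 4)]
rng = np.random.default_rng(0)
T, d = 16, 8
x = rng.standard_normal((T, d))
weights = [tuple(0.1 * rng.standard_normal((d, d)) for _ in range(3)) for _ in plan]
out, n_caches = mix_attention_forward(x, plan, weights)
print(out.shape, n_caches)  # (16, 8) 3
```

Moving the standard attention cache owner to a later layer, as Figure 4 suggests matters for long-context quality, only requires changing the `src` entries in `plan`.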