
Squeezed Attention: Accelerating Long Context Length LLM Inference

Coleman Hooper, Sehoon Kim, Hiva Mohammadzadeh, Monishwaran Maheswaran, Sebastian Zhao, June Paik, Michael W. Mahoney, Kurt Keutzer, Amir Gholami

TL;DR

Squeezed Attention introduces a semantic, centroid-based approach to accelerate attention over fixed-context inputs in long-context LLM inference. The method clusters fixed-context keys into centroids offline. Online, it matches queries against these centroids to retrieve only the semantically relevant keys, which reduces KV bandwidth and computation while preserving accuracy. A hierarchical centroid lookup further reduces lookup overhead, enabling logarithmic-like scaling with context length. Empirical results show more than 4x kernel-level speedups and up to 8x KV-budget reductions, with minimal accuracy loss on the LongBench, PreFixQA, and RULER benchmarks, demonstrating practical impact for long-context applications.
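To make the "logarithmic-like" scaling concrete, here is a rough cost model. It is not the paper's analysis. It counts query-centroid dot products for a flat 1-level lookup and for a beam search down a balanced centroid tree that is deeper than the 2-level hierarchy shown in the figures. The cluster size of 16, branching factor of 16, beam width of 4, the context lengths, and the function names are all hypothetical, and the tree count is only an approximate upper bound.

```python
import math


def flat_lookup_cost(n_keys, cluster_size):
    """Query-centroid dot products for a 1-level lookup: one per cluster."""
    return math.ceil(n_keys / cluster_size)


def tree_lookup_cost(n_keys, cluster_size, branching, beam):
    """Approximate upper bound on dot products for a beam search over a centroid tree.

    Hypothetical model: the finest level has ceil(n_keys / cluster_size) centroids,
    each coarser level merges `branching` children, and `beam` nodes are kept per level.
    """
    n_leaves = math.ceil(n_keys / cluster_size)
    # Count tree levels with integers to avoid floating-point error from math.log.
    levels, width = 1, branching
    while width < n_leaves:
        width *= branching
        levels += 1
    cost = min(branching, n_leaves)  # score every top-level centroid (at most `branching`)
    cost += (levels - 1) * beam * branching  # then only the children of the kept nodes
    return cost


for n in (32_768, 131_072, 2_097_152):
    print(n, flat_lookup_cost(n, 16), tree_lookup_cost(n, 16, branching=16, beam=4))
# 32768 2048 144
# 131072 8192 208
# 2097152 131072 272
```

Under these assumptions, the flat lookup grows in proportion to the context length. The tree lookup grows with the number of levels, which is roughly the logarithm of the context length. Neither count includes the exact attention computed over the retrieved keys.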

Abstract

Emerging Large Language Model (LLM) applications require long input context in order to perform complex tasks like document analysis and code generation. For these long context length applications, the length of the input prompt poses a significant challenge in terms of inference efficiency since the inference costs increase linearly with sequence length. However, for many of these applications, much of the context in the prompt is fixed across different user inputs, thereby providing the opportunity to perform offline optimizations in order to process user inputs quickly, as they are received. We propose Squeezed Attention to accelerate LLM applications where a large portion of the input context is fixed. We first leverage K-means clustering offline to group the keys for the fixed context based on semantic similarity and represent each cluster with a single centroid value. During inference, we compare query tokens from the user input with the centroids to predict which keys from the fixed context are semantically relevant, and then compute exact attention using only the important keys, thereby reducing bandwidth and computational costs. We also present a hierarchical version of our algorithm which can reduce the complexity of attention from linear to logarithmic with respect to the fixed context length. We evaluate our method on long-context benchmarks including LongBench, where it achieves a 3.1× reduction in KV budget with no noticeable accuracy loss and up to an 8× reduction with only a 0.5 point accuracy gap for the LLaMA-2-7B-32K, LWM-Text-Chat-1M, and Longchat-7B-v1.5-32K models. Furthermore, we implement kernels for centroid comparison and sparse FlashAttention with important keys, achieving more than 4× speedups during both the prefill and generation phases for long-context inference. Our code is available at https://github.com/SqueezeAILab/SqueezedAttention.
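The offline/online split described in the abstract can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the authors' implementation; their kernels are in the linked repository. It uses plain Lloyd's k-means on raw key vectors and scores clusters by the dot product between the query and each centroid. It then keeps a fixed number of top-scoring non-empty clusters (`top_clusters`), which is a simplification of whatever selection rule the paper actually uses. The helper names, sizes, and random keys and values are hypothetical.

```python
import math
import random


def kmeans(points, k, iters=10, seed=0):
    """Plain Lloyd's k-means (Euclidean); returns (centroids, cluster id per point)."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    assign = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest centroid by squared Euclidean distance.
        for i, p in enumerate(points):
            assign[i] = min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centroids[c])))
        # Update step: mean of each cluster (an empty cluster keeps its old centroid).
        for c in range(k):
            members = [p for p, a in zip(points, assign) if a == c]
            if members:
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids, assign


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def attention(q, keys, values):
    """Exact scaled dot-product softmax attention for one query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(len(values[0]))]


def squeezed_attention(q, keys, values, centroids, assign, top_clusters):
    """Rank non-empty clusters by query-centroid score, then attend only to their keys."""
    ranked = sorted(set(assign), key=lambda c: dot(q, centroids[c]), reverse=True)
    chosen = set(ranked[:top_clusters])
    idx = [i for i, c in enumerate(assign) if c in chosen]
    return attention(q, [keys[i] for i in idx], [values[i] for i in idx]), idx


# Offline: cluster the fixed-context keys once.
rng = random.Random(1)
d, n = 8, 256
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
centroids, assign = kmeans(keys, k=16)

# Online: a new query reads only the keys in its 4 best-matching clusters.
q = [rng.gauss(0, 1) for _ in range(d)]
out, used = squeezed_attention(q, keys, values, centroids, assign, top_clusters=4)
print(f"attended to {len(used)} of {n} keys")
```

Selecting every cluster recovers exact attention over the full fixed context. In this sketch, `top_clusters` therefore directly trades accuracy for the fraction of keys that must be loaded.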

Paper Structure

This paper contains 39 sections, 3 equations, 7 figures, 10 tables.

Figures (7)

  • Figure 1: A high-level visualization of our hierarchical clustering approach. We identify important keys for the current query by first identifying which coarse-grained clusters are relevant (Level 1). We then refine this prediction using finer-grained clustering (Level 2). Finally, we identify the important keys for the current query and only compute exact attention with these keys.
  • Figure 2: Diagram outlining our approach for performing clustering offline with the fixed context. Refer to the clustering section (sec:method-cluster) for 1-level clustering and the hierarchical lookup section (subsec:hierarchical_lookup) for hierarchical clustering. We apply K-means clustering to group semantically similar key tokens, assigning a single centroid to represent each cluster. In the hierarchical approach (shown with a 2-level hierarchy for clarity), these centroids form the Level 2 centroids, which are then clustered into coarser-grained Level 1 centroids by repeating the same procedure.
  • Figure 3: Diagram outlining how our method operates during inference to retrieve the most relevant keys when a new input query is received. Refer to the online retrieval section (sec:method-online) for 1-level retrieval and the hierarchical lookup section (subsec:hierarchical_lookup) for hierarchical retrieval. For 1-level retrieval, the query token is first compared against the representative centroid of each cluster to identify the most relevant clusters. Exact attention is then computed only for the keys within these retrieved clusters, rather than across the entire fixed context. In our hierarchical retrieval approach (shown with a 2-level hierarchy for clarity), we first compare the query with coarse-grained Level 1 centroids, and then compare it only with a subset of the promising fine-grained Level 2 centroids in order to identify the important keys.
  • Figure 4: Kernel implementation latency results for FlashAttention baseline as well as for Squeezed Attention with 70%, 80%, and 90% sparsity settings. We report latency results for prefill (with 1K and 4K input length) as well as for generation with a single input token. Latency results are normalized to the FlashAttention baseline runtime for prefill (and to our Triton FlashDecoding baseline for generation) for the same input length.
  • Figure 5: t-SNE visualization of key embeddings and their Level 1 and Level 2 clusters from LLaMA-2-7B-32K on the TREC benchmark (two attention heads, with indices 24 and 25, from layer 0). For clarity, only the top 15 Level 1 clusters nearest to the query are shown.
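The 2-level procedure in Figures 1 to 3 can be sketched the same way. In this hypothetical sketch, the Level 2 centroids come from k-means over the keys, and the Level 1 centroids come from k-means over the Level 2 centroids, as Figure 2 describes. At query time, only the Level 2 children of the best Level 1 clusters are scored, as Figure 3 describes. Ranking by dot product and keeping fixed `top_l1`/`top_l2` counts are assumptions made for illustration. The function names and sizes are not from the paper.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def kmeans(points, k, iters=8, seed=0):
    """Plain Lloyd's k-means (Euclidean); returns (centroids, cluster id per point)."""
    rng = random.Random(seed)
    cents = [list(p) for p in rng.sample(points, k)]
    assign = [0] * len(points)
    for _ in range(iters):
        for i, p in enumerate(points):
            assign[i] = min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, cents[c])))
        for c in range(k):
            members = [p for p, a in zip(points, assign) if a == c]
            if members:  # an empty cluster keeps its previous centroid
                cents[c] = [sum(col) / len(members) for col in zip(*members)]
    return cents, assign


def build_hierarchy(keys, k2, k1):
    """Offline: cluster keys into k2 Level 2 centroids, then those into k1 Level 1 centroids."""
    l2, key_to_l2 = kmeans(keys, k2)
    l1, l2_to_l1 = kmeans(l2, k1)
    return l1, l2, l2_to_l1, key_to_l2


def hierarchical_lookup(q, hierarchy, top_l1, top_l2):
    """Online: score all Level 1 centroids, then only the Level 2 children of the best ones."""
    l1, l2, l2_to_l1, key_to_l2 = hierarchy
    best_l1 = set(sorted(range(len(l1)), key=lambda c: dot(q, l1[c]), reverse=True)[:top_l1])
    candidates = [c for c in range(len(l2)) if l2_to_l1[c] in best_l1]
    best_l2 = set(sorted(candidates, key=lambda c: dot(q, l2[c]), reverse=True)[:top_l2])
    important = [i for i, c in enumerate(key_to_l2) if c in best_l2]
    comparisons = len(l1) + len(candidates)  # query-centroid dot products performed
    return important, comparisons


rng = random.Random(0)
keys = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(256)]
hierarchy = build_hierarchy(keys, k2=32, k1=4)
q = [rng.gauss(0, 1) for _ in range(8)]
important, comparisons = hierarchical_lookup(q, hierarchy, top_l1=1, top_l2=4)
print(f"{len(important)} important keys found with {comparisons} centroid comparisons (flat lookup: 32)")
```

Exact attention would then be computed only over the `important` keys, as in the 1-level case. Only the centroid comparison step changes between the two variants.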