
Discovering the Gems in Early Layers: Accelerating Long-Context LLMs with 1000x Input Token Reduction

Zhenmei Shi, Yifei Ming, Xuan-Phi Nguyen, Yingyu Liang, Shafiq Joty

TL;DR

This research tackles the long-context bottleneck to accelerate LLM inference and reduce GPU memory consumption. It proposes an algorithm that uses the early layers of an LLM as filters to select and compress input tokens, significantly reducing the context length for subsequent processing.

Abstract

Large Language Models (LLMs) have demonstrated remarkable capabilities in handling long-context inputs, but this comes at the cost of increased computational resources and latency. Our research introduces a novel approach to the long-context bottleneck that accelerates LLM inference and reduces GPU memory consumption. We demonstrate that LLMs can identify relevant tokens in the early layers, before generating answers to a query. Leveraging this insight, we propose an algorithm that uses the early layers of an LLM as filters to select and compress input tokens, significantly reducing the context length for subsequent processing. Our method, GemFilter, demonstrates substantial improvements in both speed and memory efficiency compared to existing techniques such as standard attention and SnapKV/H2O. Notably, it achieves a 2.4$\times$ speedup and a 30% reduction in GPU memory usage compared to SOTA methods. Evaluation on the Needle in a Haystack task shows that GemFilter significantly outperforms standard attention and SnapKV, and it demonstrates comparable performance on the LongBench challenge. GemFilter is simple, training-free, and broadly applicable across different LLMs. Crucially, it provides interpretability by allowing humans to inspect the selected input sequence. These findings not only offer practical benefits for LLM deployment but also enhance our understanding of LLM internal mechanisms, paving the way for further optimizations in LLM design and inference. Our code is available at https://github.com/SalesforceAIResearch/GemFilter.
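The token-selection step can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the code from the GemFilter repository: the function name `select_tokens` is hypothetical, its input is assumed to be the last row of the filter layer's attention matrix for each head (the final query token's attention over all $n$ input positions), and the per-head rows are assumed to be combined by summation before the top $k$ indices are chosen. The kept indices are re-sorted so the compressed prompt preserves the original token order.

```python
from typing import List, Sequence


def select_tokens(
    token_ids: Sequence[int],
    last_row_attn: Sequence[Sequence[float]],  # [num_heads][n]: final query's attention over the n inputs
    k: int,
) -> List[int]:
    """Hypothetical GemFilter-style selection step (a sketch, not the official code).

    Assumption: heads are aggregated by summing their last-row attention.
    """
    n = len(token_ids)
    if any(len(head) != n for head in last_row_attn):
        raise ValueError("each head's attention row must have length n")
    # One relevance score per input position, summed over heads.
    scores = [sum(head[i] for head in last_row_attn) for i in range(n)]
    # Top-k positions by score (ties broken by the lower index).
    top = sorted(range(n), key=lambda i: (-scores[i], i))[:k]
    # Restore the original order so the compressed prompt stays readable.
    return [token_ids[i] for i in sorted(top)]


tokens = [101, 7, 8, 42, 9, 10, 102]
attn = [
    [0.30, 0.01, 0.02, 0.40, 0.02, 0.05, 0.20],
    [0.25, 0.03, 0.01, 0.45, 0.01, 0.05, 0.20],
]
print(select_tokens(tokens, attn, k=3))  # → [101, 42, 102]
```

In the two-step pipeline the paper describes, the selected tokens would then be fed to the full model for standard generation; that second step is outside this sketch.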

Paper Structure

This paper contains 24 sections, 2 theorems, 7 equations, 8 figures, 2 tables, and 1 algorithm.

Key Result

Theorem 3.3

Let $n$ be the input sequence (prompt) length and $d$ the hidden feature dimension. In Algorithm 1, GemFilter uses the $r$-th layer as a filter to select $k$ input tokens. Let SnapKV and H2O also use $k$ as their cache size. Assume the LLM has $m$ attention layers, each with $h$ at…

Figures (8)

  • Figure 1: Illustration of our method GemFilter: generation with context selection based on early filter layers. We demonstrate a real Needle in a Haystack task (see the Needle in a Haystack section). The original input consists of 108,172 tokens, including the initial instruction, key message, and query. In the first step, we use the 13th layer of the LLM (LLaMA 3.1 8B Instruct) as a filter to compress the input tokens by choosing the top $k$ indices from the last row of the attention matrix. Notably, the selected input retains the initial instruction, key message, and query. GemFilter achieves a 1000$\times$ compression, reducing the input token length to 100. In the second step, we feed the selected tokens into full LLM inference using a standard generation function, which produces the correct output. GemFilter significantly reduces running time and GPU memory with negligible performance loss.
  • Figure 2: The last row of attention matrices in early layers can locate answer-related tokens.
  • Figure 3: Comparison of time and GPU memory usage across different methods on LLaMA 3.1 8B Instruct. 'gemfilter' represents our method, using the 13th layer as the filter. It achieves a 2.4$\times$ speedup and reduces GPU memory usage by 30% compared to SnapKV. Additional results can be found in the complexity experiments section.
  • Figure 4: Needle in a Haystack performance comparison of different methods using the Mistral Nemo 12B Instruct model (left column) and the LLaMA 3.1 8B Instruct model (right column). Results for the Phi 3.5 Mini 3.8B Instruct model are provided in the appendix. The $x$-axis represents the length of the input tokens, while the $y$-axis shows the position depth percentage of the 'needle' information (e.g., 0% indicates the beginning and 100% indicates the end). A higher score reflects better performance, meaning more effective retrieval of the 'needle' information. GemFilter significantly outperforms both standard attention (full KV cache) and SnapKV.
  • Figure 5: Distance between the needle position and the selected token index positions across three LLMs. The position depth percentage of the "needle" information is 50%. The $x$-axis shows the layer index of each LLM. The $y$-axis shows $\min |\text{topk\_index} - \text{needle\_index}|$, the smallest distance between a selected token index and the needle index. When $y=0$, the needle information is covered by the selected tokens. The needle information is successfully discovered in the early layers of all three LLMs.
  • ...and 3 more figures
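The distance plotted in Figure 5 can be computed as below. This sketch assumes the plotted quantity is the absolute gap between the needle index and the closest selected index; `needle_distance` is a hypothetical helper, not part of the released code.

```python
from typing import Iterable


def needle_distance(topk_indices: Iterable[int], needle_index: int) -> int:
    """Hypothetical helper: min |i - needle_index| over the selected indices i.

    A result of 0 means the needle position itself was selected.
    """
    gaps = [abs(i - needle_index) for i in topk_indices]
    if not gaps:
        raise ValueError("topk_indices must be non-empty")
    return min(gaps)


print(needle_distance([0, 3, 6], 3))     # → 0 (needle covered)
print(needle_distance([0, 10, 20], 14))  # → 4 (closest selected index is 10)
```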

Theorems & Definitions (8)

  • Definition 3.1: Single layer self-attention
  • Definition 3.2: Multi-layer transformer
  • Theorem 3.3: Complexity analysis
  • Definition A.1: Input embedding function and input tokens
  • Definition A.2: Output embedding function
  • Definition A.3: Softmax
  • Theorem B.1: Complexity analysis. Restatement of Theorem 3.3
  • Proof: Proof of Theorem 3.3
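For intuition about the "last row of the attention matrix" that GemFilter's filter layer reads, the sketch below computes that row for a single head with standard scaled dot-product softmax attention. It assumes Definition 3.1 follows the usual $\mathrm{softmax}(QK^\top/\sqrt{d})V$ form with a causal mask; `last_row_attention` is a hypothetical name. Under a causal mask the final query may attend to every position, so this row needs no masking, and computing it costs $O(nd)$ per head rather than the $O(n^2 d)$ needed for the full $n \times n$ matrix.

```python
import math
from typing import List, Sequence


def last_row_attention(query: Sequence[float], keys: Sequence[Sequence[float]]) -> List[float]:
    """Attention weights of the final query over all n keys (one head).

    Computes softmax(q K^T / sqrt(d)), where q is the last row of Q.
    """
    if not keys:
        raise ValueError("keys must be non-empty")
    d = len(query)
    logits = [sum(qi * ki for qi, ki in zip(query, key)) / math.sqrt(d) for key in keys]
    peak = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


weights = last_row_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
print([round(w, 3) for w in weights])  # → [0.401, 0.198, 0.401]
```

Summing such rows over heads and taking the top $k$ positions gives the kind of selection shown in Figure 1.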