Neural Attention Search

Difan Deng, Marius Lindauer

TL;DR

NAtS tackles the costly KV cache and long-context inference in transformers by learning a role for each token (Global, Local, or Sliding Window) through a differentiable token-type search guided by Gumbel-Softmax. It constructs a learnable attention mask that sparsifies attention end to end, enabling substantial KV-cache reductions while preserving model performance both when training from scratch and when fine-tuning. Empirical results on PG-19 and LongBench show NAtS achieving strong perplexity or task performance with far smaller KV budgets, and latency and memory analyses show large improvements in very long context settings. The approach offers a practical pathway to scalable, efficient long-context inference for LLMs and related transformer architectures, with potential applicability to both the prefill and decoding phases. The $O(L^2)$ attention cost can thus be pushed toward a regime dominated by KV-cache operations, enabling longer contexts and more capable models in real-world deployments.
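The role selection described above can be sketched with the standard Gumbel-Softmax trick. This minimal, dependency-free sketch assumes each token already has three logits (Global, Local, Sliding Window); how NAtS produces those logits and schedules the temperature is not covered here, and the function names are hypothetical.

```python
import math
import random

TOKEN_TYPES = ("global", "local", "sliding_window")


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Relaxed one-hot sample over the token types.

    Adds i.i.d. Gumbel(0, 1) noise g_k = -log(-log(u_k)) to each logit and
    applies a temperature-scaled softmax: y_k is proportional to
    exp((logit_k + g_k) / tau).
    """
    noisy = []
    for logit in logits:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / tau)
    m = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token_roles(per_token_logits, tau=1.0, seed=0):
    """Return soft role probabilities and hard (argmax) roles for each token."""
    rng = random.Random(seed)
    soft, hard = [], []
    for logits in per_token_logits:
        probs = gumbel_softmax(logits, tau, rng)
        soft.append(probs)
        hard.append(TOKEN_TYPES[probs.index(max(probs))])
    return soft, hard


soft, hard = sample_token_roles([[100.0, 0.0, 0.0], [0.0, 0.0, 100.0]])
print(hard)
```

A common way to train through the discrete choice is a straight-through estimator: use the hard role in the forward pass and the soft probabilities for gradients. That step needs an autodiff framework and is omitted; whether NAtS uses this exact estimator is an assumption.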

Abstract

We present Neural Attention Search (NAtS), a framework that automatically evaluates the importance of each token within a sequence and determines whether the corresponding token can be dropped after several steps. This approach can efficiently reduce the KV cache size required by transformer-based models during inference and thus reduce inference costs. In this paper, we design a search space that contains three token types: (i) Global Tokens are preserved and can be queried by all subsequent tokens. (ii) Local Tokens survive until the next global token appears. (iii) Sliding Window Tokens influence only a fixed number of subsequent tokens. Similar to the One-Shot Neural Architecture Search approach, this token-type information can be learned jointly with the architecture weights via a learnable attention mask. Experiments on both training a new transformer from scratch and fine-tuning existing large language models show that NAtS can efficiently reduce the KV cache size required for the models while maintaining the models' performance.
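To make the three token types concrete, here is a minimal sketch of how per-token roles could translate into a causal attention mask. The visibility rules are our reading of the abstract, not the paper's exact definition: we assume a token always attends to itself, that a Global token evicts every earlier Local token (including for the Global token's own query), and that the sliding-window size counts the query itself. Function and constant names are hypothetical.

```python
GLOBAL, LOCAL, SLIDING = "G", "L", "S"


def nats_style_mask(roles, window):
    """Build a causal boolean mask: mask[i][j] is True if query i may attend key j.

    Assumed visibility rules (a hypothetical reading of the abstract):
      * a key always sees itself (j == i);
      * Global keys stay visible to every later query;
      * Local keys stay visible until a Global token arrives at some
        position k with j < k <= i;
      * Sliding Window keys stay visible while i - j < window.
    """
    n = len(roles)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: only keys at or before the query
            if j == i or roles[j] == GLOBAL:
                visible = True
            elif roles[j] == LOCAL:
                visible = all(roles[k] != GLOBAL for k in range(j + 1, i + 1))
            elif roles[j] == SLIDING:
                visible = i - j < window
            else:
                raise ValueError(f"unknown role {roles[j]!r}")
            mask[i][j] = visible
    return mask


for row in nats_style_mask(["G", "L", "S", "L", "G", "S"], window=2):
    print("".join("x" if v else "." for v in row))
```

For roles `G L S L G S` with `window=2`, query 3 still sees every earlier token, but once the Global token at position 4 arrives, the Local tokens at positions 1 and 3 and the Sliding Window token at position 2 drop out, so query 5 sees only positions 0, 4, and 5.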

Paper Structure

This paper contains 34 sections, 15 equations, 11 figures, 8 tables, 3 algorithms.

Figures (11)

  • Figure 1: A comparison between different causal attention maps. (a) The full attention map, where each token attends to all tokens before it. (b) Local attention with a sliding window of size 3: every token only accesses the information of the 3 tokens preceding it. (c) Longformer: in addition to local attention, the 1st, 6th, and 9th tokens are predefined global tokens. (d) NAtS dynamically learns the optimal role for each token and constructs a learnable mask based on the tokens' roles.
  • Figure 2: An example of how caches are updated within NAtS when new tokens arrive, for a model with two attention heads. The two rows correspond to the two heads.
  • Figure 3: Perplexity vs. KV cache size under different sparsity settings $\lambda$ on the PG-19 dataset.
  • Figure 4: Memory usage and latency during pre-filling (left) and decoding (right).
  • Figure 5: The gradient term $d \alpha^{G-}_{i}$ for Token 4 and 7
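The fixed patterns in Figure 1 (full, sliding-window, and Longformer-style attention) can be written as boolean causal masks. This sketch is illustrative: the function names are hypothetical, positions are 0-based (so the 1st, 6th, and 9th tokens are positions 0, 5, and 8), and we assume the window size counts the query token itself. We also assume, as in Longformer, that a global token attends to all earlier tokens in addition to being visible to all later ones.

```python
def full_causal_mask(n):
    """(a) Full attention: query i attends every key j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def sliding_window_mask(n, window):
    """(b) Local attention: query i attends keys j with 0 <= i - j < window."""
    return [[0 <= i - j < window for j in range(n)] for i in range(n)]


def longformer_causal_mask(n, window, global_positions):
    """(c) Longformer-style causal mask: local window plus fixed global tokens.

    Global positions are visible to all later queries, and a global query
    attends to every earlier key.
    """
    g = set(global_positions)
    return [
        [j <= i and (i - j < window or j in g or i in g) for j in range(n)]
        for i in range(n)
    ]


# Count the attended (query, key) pairs for a 10-token sequence.
print(sum(map(sum, full_causal_mask(10))))
print(sum(map(sum, sliding_window_mask(10, 3))))
print(sum(map(sum, longformer_causal_mask(10, 3, [0, 5, 8]))))
```

The sliding window caps each query at a constant number of keys, whereas full attention grows linearly with position; the Longformer mask restores long-range links only at hand-picked positions, which NAtS instead learns per token, as in panel (d).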
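Figure 2 shows the caches of two heads being updated as tokens arrive. A minimal decoding-time sketch of that bookkeeping follows, assuming each head assigns its own role to every token and applying hypothetical eviction rules based on the abstract: Global entries are kept, Local entries are dropped when a Global token arrives, and Sliding Window entries are dropped once they are `window` or more positions behind the newest token. The function names and the role encoding ("G", "L", "S") are illustrative, not taken from the paper.

```python
def update_cache(cache, new_pos, new_role, window):
    """Append one token to a single head's cache and evict expired entries.

    cache: list of (position, role) pairs currently stored for this head.
    Assumed eviction rules (hypothetical, mirroring the three token types):
      * Sliding Window entries leave once new_pos - pos >= window;
      * Local entries leave when a Global token arrives;
      * Global entries are never evicted.
    """
    kept = []
    for pos, role in cache:
        if role == "S" and new_pos - pos >= window:
            continue  # fell out of the sliding window
        if role == "L" and new_role == "G":
            continue  # a new Global token ends the Local token's lifetime
        kept.append((pos, role))
    kept.append((new_pos, new_role))
    return kept


def run_decoding(roles_per_head, window):
    """Feed tokens one at a time; return each head's final cached positions."""
    caches = [[] for _ in roles_per_head]
    n = len(roles_per_head[0])
    for pos in range(n):
        for h, roles in enumerate(roles_per_head):
            caches[h] = update_cache(caches[h], pos, roles[pos], window)
    return [[pos for pos, _ in cache] for cache in caches]


print(run_decoding([list("GLSLGS"), list("SSGLSL")], window=2))
```

With two heads over six tokens and `window=2`, the call above leaves `[[0, 4, 5], [2, 3, 4, 5]]`: the heads end up with different cache sizes, and both hold fewer entries than the six a full KV cache would store.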