Neural Attention Search
Difan Deng, Marius Lindauer
TL;DR
NAtS reduces the KV-cache cost of long-context transformer inference by learning a role for each token (Global, Local, or Sliding Window) through a differentiable token-type search based on Gumbel-Softmax. The learned roles define a learnable attention mask that sparsifies attention end to end, which substantially shrinks the KV cache while preserving model performance, both when training from scratch and when fine-tuning. On PG-19 and LongBench, NAtS reaches strong perplexity or task performance with far smaller KV budgets, and latency and memory analyses show large improvements at very long context lengths. The approach offers a practical path to efficient long-context inference for LLMs and related transformer architectures, with potential applicability to both the prefill and decoding phases. Because fewer cached tokens are attended to, the $O(L^2)$ attention cost over a length-$L$ context is mitigated, which makes larger contexts more practical in real-world deployments.
Abstract
We present Neural Attention Search (NAtS), a framework that automatically evaluates the importance of each token within a sequence and determines whether that token can be dropped after several steps. This approach reduces the KV cache size required by transformer-based models during inference and thus reduces inference costs. In this paper, we design a search space that contains three token types: (i) Global Tokens are preserved and queried by all following tokens. (ii) Local Tokens survive only until the next Global Token appears. (iii) Sliding Window Tokens influence inference only for a fixed-size window of subsequent tokens. Similar to One-Shot Neural Architecture Search, this token-type information can be learned jointly with the architecture weights via a learnable attention mask. Experiments on both training a new transformer from scratch and fine-tuning existing large language models show that NAtS can efficiently reduce the KV cache size required by the models while maintaining their performance.
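To make the three token types concrete, the following is a minimal sketch of how fixed token-type assignments could translate into a causal attention mask and a per-step KV cache size. It is an illustration under stated assumptions, not the NAtS implementation: the names `build_mask` and `kv_cache_sizes` are hypothetical, the token types are given by hand rather than learned, a Local Token is assumed to remain visible to the next Global Token itself before being dropped, and a Sliding Window Token is assumed to count itself within its window.

```python
GLOBAL, LOCAL, SLIDING = "G", "L", "S"


def build_mask(token_types, window):
    """Return mask[i][j] == True when query i may attend to key j.

    Assumptions (not taken from the paper): a Local Token stays visible
    up to and including the next Global Token, and a Sliding Window Token
    is visible to itself plus the following `window - 1` positions.
    """
    n = len(token_types)
    mask = [[False] * n for _ in range(n)]
    for j, kind in enumerate(token_types):
        # Position of the first Global Token after j (n if there is none).
        next_global = next(
            (k for k in range(j + 1, n) if token_types[k] == GLOBAL), n
        )
        for i in range(j, n):  # causal: only queries at or after j
            if kind == GLOBAL:
                visible = True
            elif kind == LOCAL:
                visible = i <= next_global
            else:  # SLIDING
                visible = i - j < window
            mask[i][j] = visible
    return mask


def kv_cache_sizes(mask):
    """Number of keys that must be kept for each query position."""
    return [sum(row) for row in mask]


types = [GLOBAL, LOCAL, SLIDING, LOCAL, GLOBAL, SLIDING, LOCAL]
mask = build_mask(types, window=2)
sizes = kv_cache_sizes(mask)
print(sizes)  # [1, 2, 3, 4, 4, 3, 4]
```

With full causal attention the same seven-token sequence would need cache sizes 1 through 7. In this sketch the cache stops growing once the second Global Token arrives, because the earlier Local and Sliding Window Tokens are no longer visible to later queries.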
