LUNA: Linear Universal Neural Attention with Generalization Guarantees
Ashkan Shahbazi, Ping He, Ali Abbasi, Yikun Bai, Xinran Liu, Elaheh Akbari, Darian Salehi, Navid NaderiAlizadeh, Soheil Kolouri
TL;DR
The paper tackles the quadratic cost of softmax attention by introducing LUNA, a linear-time attention mechanism built around a fully learnable kernel feature map. By learning input projections, a bank of channel functions, and a tokenwise envelope, LUNA preserves the linear compute pattern while adapting to data-specific inductive biases; it also supports effective post-hoc conversion of pretrained quadratic-attention models. The authors provide a theoretical framework that includes a Rademacher-complexity generalization bound and a decomposition of the approximation error into a parametrization term and a sampling term. Empirically, they report state-of-the-art results on Long Range Arena under compute parity and strong post-hoc conversion performance on BERT/GLUE and ViT/ImageNet-1K. Overall, LUNA enables accurate, scalable attention for long sequences and practical deployment in existing systems, requiring only brief fine-tuning after conversion.
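For background on the "linear compute pattern," the standard kernelized linear-attention identity is written out below. Here $\phi$ is a generic feature map into $\mathbb{R}^m$ and $v_j \in \mathbb{R}^{d_v}$. This illustrates the general mechanism and is not claimed to be LUNA's exact formulation. Replacing the softmax score with $\phi(q_i)^\top \phi(k_j)$ lets the key/value sums be computed once and shared by every query:

```latex
o_i^\top
  = \frac{\sum_{j=1}^{n} \phi(q_i)^\top \phi(k_j)\, v_j^\top}{\sum_{j=1}^{n} \phi(q_i)^\top \phi(k_j)}
  = \frac{\phi(q_i)^\top S}{\phi(q_i)^\top z},
\qquad
S = \sum_{j=1}^{n} \phi(k_j)\, v_j^\top \in \mathbb{R}^{m \times d_v},
\quad
z = \sum_{j=1}^{n} \phi(k_j) \in \mathbb{R}^{m}.
```

Forming $S$ and $z$ costs $\mathcal{O}(n\, m\, d_v)$, and each output then costs $\mathcal{O}(m\, d_v)$. The total is therefore linear in $n$, whereas materializing all $n^2$ pairwise scores costs $\mathcal{O}(n^2)$.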
Abstract
Scaling attention faces a critical bottleneck: the quadratic $\mathcal{O}(n^2)$ computational cost of softmax attention, which limits its application in long-sequence domains. Linear attention mechanisms reduce this cost to $\mathcal{O}(n)$, but they typically rely on fixed feature maps, either random (such as random Fourier features) or hand-crafted. This reliance on static, data-agnostic kernels creates a fundamental trade-off that forces practitioners to sacrifice significant model accuracy for computational efficiency. We introduce \textsc{LUNA}, a kernelized linear attention mechanism that eliminates this trade-off, retaining linear cost while matching and surpassing the accuracy of quadratic attention. \textsc{LUNA} is built on the key insight that the kernel feature map itself should be learned rather than fixed a priori. By parameterizing the kernel, \textsc{LUNA} learns a feature basis tailored to the specific data and task, overcoming the expressive limitations of fixed-feature methods. \textsc{LUNA} implements this with a learnable feature map that induces a positive-definite kernel and admits a streaming form, yielding time and memory that scale linearly in the sequence length. Empirical evaluations validate our approach across diverse settings. On the Long Range Arena (LRA), \textsc{LUNA} achieves state-of-the-art average accuracy among efficient Transformers under compute parity (the same parameter count, training steps, and approximate FLOPs). \textsc{LUNA} also excels at post-hoc conversion: replacing softmax in fine-tuned BERT and ViT-B/16 checkpoints and briefly fine-tuning recovers most of the original performance, substantially outperforming fixed linearizations.
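The minimal NumPy sketch below illustrates the streaming form and checks it against explicit quadratic kernel attention. It assumes a generic learnable feature map $\phi(x) = \mathrm{softplus}(xW + b)$, a stand-in chosen only for illustration. It is not LUNA's actual parameterization, which, per the TL;DR, combines input projections, a bank of channel functions, and a tokenwise envelope. All function names are hypothetical. Because every feature is strictly positive, $\phi(q)^\top \phi(k)$ is a valid positive-definite kernel (every Gram matrix $\Phi\Phi^\top$ is positive semi-definite) and the normalizer never vanishes. The causal loop carries only an $m \times d_v$ state and an $m$-vector, whose sizes do not depend on the sequence length $n$.

```python
import numpy as np


def feature_map(x, W, b):
    """Hypothetical learnable feature map phi(x) = softplus(x @ W + b).

    W (d, m) and b (m,) stand in for learned parameters. Softplus keeps every
    feature strictly positive, so the induced kernel phi(q) . phi(k) is
    positive and its Gram matrices are positive semi-definite.
    """
    return np.logaddexp(0.0, x @ W + b)  # numerically stable log(1 + exp(z))


def quadratic_kernel_attention(Q, K, V, W, b, causal=False):
    """Reference O(n^2) version: materialize all pairwise kernel scores."""
    A = feature_map(Q, W, b) @ feature_map(K, W, b).T  # (n, n) scores
    if causal:
        A = np.tril(A)  # query i attends only to keys j <= i
    return (A @ V) / A.sum(axis=1, keepdims=True)


def linear_attention(Q, K, V, W, b):
    """Non-causal O(n) version via associativity: phi(Q) (phi(K)^T V)."""
    phi_q, phi_k = feature_map(Q, W, b), feature_map(K, W, b)  # (n, m)
    S = phi_k.T @ V            # (m, d_v) key/value summary
    z = phi_k.sum(axis=0)      # (m,) normalizer summary
    return (phi_q @ S) / (phi_q @ z)[:, None]


def streaming_causal_attention(Q, K, V, W, b):
    """Causal streaming form: running sums S_t and z_t updated per token."""
    phi_q, phi_k = feature_map(Q, W, b), feature_map(K, W, b)
    m, d_v = phi_k.shape[1], V.shape[1]
    S = np.zeros((m, d_v))     # S_t = sum_{j<=t} phi(k_j) v_j^T
    z = np.zeros(m)            # z_t = sum_{j<=t} phi(k_j)
    out = np.empty((Q.shape[0], d_v))
    for t in range(Q.shape[0]):
        S += np.outer(phi_k[t], V[t])
        z += phi_k[t]
        out[t] = (phi_q[t] @ S) / (phi_q[t] @ z)
    return out


# Sanity check on random inputs: both linear forms match the quadratic reference.
rng = np.random.default_rng(0)
n, d, m, d_v = 16, 8, 32, 4
Q, K = rng.standard_normal((n, d)), rng.standard_normal((n, d))
V = rng.standard_normal((n, d_v))
W, b = rng.standard_normal((d, m)) / np.sqrt(d), np.zeros(m)

assert np.allclose(linear_attention(Q, K, V, W, b),
                   quadratic_kernel_attention(Q, K, V, W, b))
assert np.allclose(streaming_causal_attention(Q, K, V, W, b),
                   quadratic_kernel_attention(Q, K, V, W, b, causal=True))
```

The quadratic function exists only as a correctness reference. In a trained model, $W$ and $b$ (or a richer parameterization) would be learned by gradient descent, while the $\mathcal{O}(n)$ computation pattern stays the same.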
