$k$NN Attention Demystified: A Theoretical Exploration for Scalable Transformers

Themistoklis Haris

TL;DR

A theoretical framework for $k$NN attention is established, reformulating self-attention as expectations over softmax distributions and leveraging lazy Gumbel sampling with $k$NN indices for efficient approximation. Novel sub-quadratic algorithms that approximate self-attention gradients are also proposed.

Abstract

Despite their power, Transformers face challenges with long sequences due to the quadratic complexity of self-attention. To address this limitation, methods like $k$-Nearest-Neighbor ($k$NN) attention have been introduced [Roy, Saffar, Vaswani, Grangier, 2021], enabling each token to attend only to its $k$ closest tokens. While $k$NN attention has shown empirical success in making Transformers more efficient, its exact approximation guarantees have not been theoretically analyzed. In this work, we establish a theoretical framework for $k$NN attention, reformulating self-attention as expectations over softmax distributions and leveraging lazy Gumbel sampling [Mussmann, Levy, Ermon, 2017] with $k$NN indices for efficient approximation. Building on this framework, we also propose novel sub-quadratic algorithms that approximate self-attention gradients by leveraging efficient sampling techniques, such as Markov Chain-based estimation. Finally, we demonstrate the practical effectiveness of these algorithms through empirical experiments, showcasing their benefits in both training and inference.
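
To make the expectation view concrete: for a query $q_i$, the attention output is $\sum_j p_j V_j$ with $p = \text{softmax}(q_i K^\top)$, that is, the expectation of $V_j$ when $j$ is drawn from that softmax distribution. $k$NN attention restricts the sum to the $k$ keys with the largest scores. The following is a minimal sketch of that idea, not the paper's algorithm: it assumes unscaled dot-product scores, uses a brute-force sort where a real system would query a $k$NN index, and all function names are illustrative.

```python
import math
import random

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def attention_row(q, K, V, k=None):
    """Attention output for one query q.

    With k=None this is exact softmax attention, i.e. the expectation of
    V[j] under j ~ softmax(q . K[j]). With an integer k, only the k keys
    with the largest scores are kept (a brute-force stand-in for a kNN index).
    """
    scores = [dot(q, key) for key in K]
    idx = list(range(len(K)))
    if k is not None:
        idx = sorted(idx, key=lambda j: scores[j], reverse=True)[:k]
    probs = softmax([scores[j] for j in idx])
    d = len(V[0])
    return [sum(p * V[j][c] for p, j in zip(probs, idx)) for c in range(d)]

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
n, d = 64, 8  # assumed toy sizes
Q = random_matrix(n, d, rng)
K = random_matrix(n, d, rng)
V = random_matrix(n, d, rng)

exact = [attention_row(q, K, V) for q in Q]
approx = [attention_row(q, K, V, k=16) for q in Q]
max_err = max(abs(a - b) for ra, rb in zip(exact, approx) for a, b in zip(ra, rb))
print("max coordinate error with k=16:", max_err)
```

The brute-force sort still touches every key, so this sketch only illustrates what is being approximated; the sub-quadratic cost comes from replacing that sort with an approximate nearest-neighbor lookup.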

Paper Structure

This paper contains 42 sections, 21 theorems, 88 equations, 7 figures, 1 table, 8 algorithms.

Key Result

Lemma 1

If $\widehat{Q}$ is an unbiased estimator of some statistic, then one can obtain an $(\varepsilon,\delta)$-multiplicative estimate of that statistic by suitably combining $K := \frac{C}{\varepsilon^2}\frac{\text{Var}[\widehat{Q}]}{\mathbb{E}[\widehat{Q}]^2}\ln \frac{2}{\delta}$ independent samples o
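
The boosting in Lemma 1 is commonly realized by the median-of-means construction: average the estimator within several independent groups, then take the median of the group averages. The sketch below assumes a hypothetical unbiased estimator (an exponential variable with mean 5) and arbitrary group counts; the constant $C$ and the exact group sizes from the lemma are not reproduced here.

```python
import random
import statistics

def median_of_means(sample, num_groups, group_size):
    """Combine num_groups * group_size independent draws of an unbiased estimator.

    sample: zero-argument callable returning one draw of the estimator.
    Averaging within a group reduces variance; taking the median across
    groups makes a large deviation exponentially unlikely in num_groups.
    """
    group_means = []
    for _ in range(num_groups):
        total = sum(sample() for _ in range(group_size))
        group_means.append(total / group_size)
    return statistics.median(group_means)

rng = random.Random(0)

def noisy_estimator():
    # Hypothetical estimator: exponential with mean 5 (variance 25).
    return rng.expovariate(1.0 / 5.0)

estimate = median_of_means(noisy_estimator, num_groups=15, group_size=400)
print("median-of-means estimate of the mean:", estimate)
```

In the usual analysis the group size scales with $\text{Var}[\widehat{Q}] / (\varepsilon^2 \, \mathbb{E}[\widehat{Q}]^2)$ and the number of groups with $\ln(1/\delta)$, which matches the shape of $K$ in the lemma.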

Figures (7)

  • Figure 1: Lazy Gumbel sampling
  • Figure 2: A single-step Markov Chain sample.
  • Figure 3: Gradient Descent with Approximate Gradients against different loss functions $\phi$. Even with approximate gradients, gradient descent still makes adequate progress towards convergence.
  • Figure 4: The perplexity and approximation error of $k$NN Attention throughout training
  • Figure 5: An illustration of the concentric LSH construction of [Mussmann, Levy, Ermon, 2017]. In the $D_{i+1}$ band we find at least $\sqrt{n}$ points and in the $D_i$ band we find fewer than $\sqrt{n}$ points.
  • ...and 2 more figures
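
Lazy Gumbel sampling (Figure 1) builds on the Gumbel-max trick: adding independent Gumbel noise to each score and taking the argmax yields an index distributed according to the softmax of the scores. The sketch below shows only the plain, non-lazy trick, which perturbs all $n$ scores; the lazy variant of [Mussmann, Levy, Ermon, 2017] avoids this by perturbing the top-scoring entries plus a small random subset of the rest. Function names and the toy scores are illustrative assumptions.

```python
import math
import random
from collections import Counter

def sample_gumbel(rng):
    # Standard Gumbel draw via inverse CDF; resample u == 0 to avoid log(0).
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(scores, rng):
    """Return index j with probability softmax(scores)[j]."""
    perturbed = [s + sample_gumbel(rng) for s in scores]
    return max(range(len(scores)), key=perturbed.__getitem__)

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
scores = [0.0, 1.0, 2.0]  # assumed toy attention scores
num_draws = 20000
counts = Counter(gumbel_max_sample(scores, rng) for _ in range(num_draws))
empirical = [counts[j] / num_draws for j in range(len(scores))]
target = softmax(scores)
print("empirical frequencies:", empirical)
print("softmax probabilities:", target)
```

Comparing the two printed lists gives a quick sanity check that the perturbed argmax follows the softmax distribution.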

Theorems & Definitions (39)

  • Lemma 1: Median-of-Means Boosting [Chakrabarti, 2020]
  • Theorem 2
  • Theorem 3
  • Theorem 4
  • proof
  • Theorem 5: Correctness of the lazy Gumbel sampling algorithm
  • proof
  • Lemma 6
  • Theorem 7
  • Theorem 8
  • ...and 29 more