
KDEformer: Accelerating Transformers via Kernel Density Estimation

Amir Zandieh, Insu Han, Majid Daliri, Amin Karbasi

TL;DR

KDEformer tackles the quadratic bottleneck of Transformer attention by reducing the softmax partition functions to a variant of Kernel Density Estimation (KDE) and using the resulting KDE solver to drive subsampling-based matrix products. Built on a Weighted Exponential KDE primitive and Approximate Matrix Multiplication (AMM), the framework yields provable spectral-norm guarantees for Att(Q,K,V) in sub-quadratic time. The theoretical analysis characterizes the sampling complexity and stability of the approximation using the structure of the softmax matrix. In practice, sparsity and LSH are used to capture the large entries of the attention matrix, leaving a residual with a much smaller stable rank to be sampled; this yields substantial memory savings and speedups with minimal accuracy loss on vision and generative models. Empirical results on BigGAN, ImageNet classification with T2T-ViT, and Long Range Arena show sizable efficiency gains with competitive accuracy, highlighting the method's potential for scalable long-sequence attention in real-world applications.
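One way to see the KDE connection: each softmax normalizer is a sum of exponential kernels $\exp(\langle q, k_j\rangle)$, and the identity $\exp(\langle q,k\rangle) = e^{\|q\|^2/2}\, e^{\|k\|^2/2}\, e^{-\|q-k\|^2/2}$ rewrites that sum as a Gaussian KDE with per-key weights $e^{\|k_j\|^2/2}$. The minimal Python sketch below checks this numerically. The function names are hypothetical, the $1/\sqrt{d}$ temperature is assumed to be folded into both queries and keys, and the paper's Definition 3.1 (Weighted Exponential KDE) may state the primitive differently.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def partition_exact(q, keys):
    """Softmax normalizer for one query: sum_j exp(<q, k_j>)."""
    return sum(math.exp(dot(q, k)) for k in keys)


def partition_as_weighted_gaussian_kde(q, keys):
    """Same quantity written as a weighted Gaussian KDE:
    exp(||q||^2 / 2) * sum_j w_j * exp(-||q - k_j||^2 / 2), with w_j = exp(||k_j||^2 / 2)."""
    total = 0.0
    for k in keys:
        diff = [a - b for a, b in zip(q, k)]
        weight = math.exp(dot(k, k) / 2)                  # per-key weight w_j
        total += weight * math.exp(-dot(diff, diff) / 2)  # Gaussian kernel value
    return math.exp(dot(q, q) / 2) * total


rng = random.Random(0)
d, n = 4, 50
scale = d ** -0.25  # folds the 1/sqrt(d) attention temperature into both q and k
q = [rng.gauss(0, 1) * scale for _ in range(d)]
keys = [[rng.gauss(0, 1) * scale for _ in range(d)] for _ in range(n)]
print(partition_exact(q, keys), partition_as_weighted_gaussian_kde(q, keys))
```

The two printed values agree up to floating-point rounding, which is why a fast (weighted) Gaussian KDE solver can stand in for the exact partition function.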

Abstract

The dot-product attention mechanism plays a crucial role in modern deep architectures (e.g., Transformers) for sequence modeling; however, naïve exact computation of attention incurs time and memory that are quadratic in the sequence length, hindering the training of long-sequence models. The critical bottlenecks are the computation of the partition functions in the denominator of the softmax function and the multiplication of the softmax matrix with the matrix of values. Our key observation is that the former can be reduced to a variant of the kernel density estimation (KDE) problem, and that an efficient KDE solver can further be used to accelerate the latter via subsampling-based fast matrix products. Our proposed KDEformer approximates attention in sub-quadratic time with provable spectral norm bounds, whereas all prior results provide only entry-wise error bounds. Empirically, we verify that KDEformer outperforms other attention approximations in terms of accuracy, memory, and runtime on various pre-trained models. On BigGAN image generation, we achieve better generative scores than exact computation with over $4\times$ speedup. For ImageNet classification with T2T-ViT, KDEformer achieves over $18\times$ speedup with an accuracy drop of less than $0.5\%$.
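The second bottleneck, multiplying the softmax matrix by the values, is where subsampling-based fast matrix products come in. The sketch below shows the generic approximate matrix multiplication (AMM) estimator in plain Python: it samples $m$ inner indices $t$ with probability $p_t$ proportional to $\|X_{:,t}\|\,\|Y_{t,:}\|$ and rescales each sampled outer product by $1/(m p_t)$, which makes the estimate unbiased. For clarity it forms the softmax matrix exactly, which is itself quadratic; per the abstract, KDEformer instead relies on its KDE solver so the full matrix is never materialized, and its sampling distribution may differ from the one assumed here. All function names are hypothetical.

```python
import math
import random


def matmul(X, Y):
    """Exact product of two list-of-lists matrices."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def amm(X, Y, m, rng):
    """Unbiased sampling estimate of X @ Y from m inner indices drawn with replacement.

    p_t is proportional to ||X[:, t]|| * ||Y[t, :]||; each sample is rescaled by 1 / (m * p_t).
    Assumes at least one of these weights is nonzero.
    """
    inner, rows, cols = len(Y), len(X), len(Y[0])
    weights = [math.sqrt(sum(X[i][t] ** 2 for i in range(rows)))
               * math.sqrt(sum(y * y for y in Y[t])) for t in range(inner)]
    total = sum(weights)
    probs = [w / total for w in weights]
    est = [[0.0] * cols for _ in range(rows)]
    for t in rng.choices(range(inner), weights=probs, k=m):
        scale = 1.0 / (m * probs[t])
        for i in range(rows):
            for j in range(cols):
                est[i][j] += scale * X[i][t] * Y[t][j]
    return est


def softmax_rows(Q, K):
    """Row-normalized exp(<q_i, k_j>), i.e. D^{-1} A, computed exactly here (O(n^2))."""
    P = []
    for q in Q:
        e = [math.exp(sum(a * b for a, b in zip(q, k))) for k in K]
        z = sum(e)  # partition function of this row
        P.append([x / z for x in e])
    return P


def frob(M):
    return math.sqrt(sum(x * x for row in M for x in row))


rng = random.Random(0)
n, d = 64, 8
Q = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(n)]
K = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(n)]
V = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
P = softmax_rows(Q, K)
exact = matmul(P, V)
for m in (8, 32, 128):
    approx = amm(P, V, m, rng)
    diff = [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(approx, exact)]
    print(m, frob(diff) / frob(exact))  # relative Frobenius error for this sample size
```

Larger $m$ generally lowers the error; the classical AMM guarantee bounds the expected squared Frobenius error by a quantity that shrinks like $1/m$.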


Paper Structure

This paper contains 31 sections, 11 theorems, 42 equations, 7 figures, 3 tables, and 3 algorithms.

Key Result

Theorem 2.1

Let $\tau = 0.173+o(1)$. For any dataset ${\rm X} \in \mathbb{R}^{n \times d}$ and any $\varepsilon, \widetilde{\mu} \in (0,1)$, there exist the following procedures:

Figures (7)

  • Figure 1: Image generations by the pre-trained BigGAN using exact and approximate attention without fine-tuning.
  • Figure 2: Singular value distribution and stable rank of the softmax matrix ${\rm D}^{-1} {\rm A}$ versus those of the residual ${\rm D}^{-1} {\rm A}_{\tt res}$. The residual has a significantly smaller stable rank.
  • Figure 3: The softmax matrix ${\rm D}^{-1} {\rm A}$ decomposes into its sparse approximation ${\rm D}^{-1} {\rm A}_{\tt spar}$, which captures the large entries (coded with darker colors), and the residual ${\rm D}^{-1} {\rm A}_{\tt res}$, in which black cells mark entries already captured by ${\rm D}^{-1} {\rm A}_{\tt spar}$. Blank cells in ${\rm D}^{-1} {\rm A}_{\tt res}$ mark columns not sampled by the AMM sampling matrix ${\rm \Pi}_{\tt res}$.
  • Figure 4: Performance evaluations of various self-attention approximations on approximating under the GloVe word embeddings.
  • Figure 5: Rank-$2$ Angular LSH in action (in dimension $d=2$). The space partitions corresponding to buckets at unit Hamming distance are neighbors in $\mathbb{R}^d$. In the first panel, hashing an example dataset yields uneven buckets. The second panel shows that ordering the points by the Hamming distance of their buckets and then truncating yields new equal-sized buckets with minimal spillover.
  • ...and 2 more figures
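Figures 2 and 3 rest on two quantities that are easy to compute for small matrices: the stable rank $\|M\|_F^2 / \|M\|_2^2$, and a split of the softmax matrix into a sparse part holding its large entries plus a residual. The plain-Python sketch below estimates $\|M\|_2$ by power iteration and, as a stand-in for KDEformer's LSH-based detection of large entries, simply keeps the top-$k$ entries of each row (an oracle that itself needs the full matrix). The function names and the top-$k$ rule are assumptions for illustration only.

```python
import math
import random


def spectral_norm(M, iters=100, seed=0):
    """Estimate the largest singular value of M by power iteration on M^T M."""
    rng = random.Random(seed)
    rows, cols = len(M), len(M[0])
    v = [rng.gauss(0, 1) for _ in range(cols)]
    for _ in range(iters):
        u = [sum(M[i][j] * v[j] for j in range(cols)) for i in range(rows)]  # u = M v
        w = [sum(M[i][j] * u[i] for i in range(rows)) for j in range(cols)]  # w = M^T u
        norm_w = math.sqrt(sum(x * x for x in w))
        if norm_w == 0.0:
            return 0.0
        v = [x / norm_w for x in w]
    Mv = [sum(M[i][j] * v[j] for j in range(cols)) for i in range(rows)]
    return math.sqrt(sum(x * x for x in Mv))


def stable_rank(M):
    """||M||_F^2 / ||M||_2^2: at least 1, and at most rank(M) when ||M||_2 is exact."""
    sigma = spectral_norm(M)
    if sigma == 0.0:
        return 0.0
    return sum(x * x for row in M for x in row) / sigma ** 2


def split_sparse_residual(P, k):
    """Top-k entries of each row go to the sparse part; all other entries to the residual."""
    sparse, residual = [], []
    for row in P:
        keep = set(sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k])
        sparse.append([x if j in keep else 0.0 for j, x in enumerate(row)])
        residual.append([0.0 if j in keep else x for j, x in enumerate(row)])
    return sparse, residual


def softmax_rows(Q, K):
    """Row-normalized exp(<q_i, k_j>), i.e. D^{-1} A."""
    P = []
    for q in Q:
        e = [math.exp(sum(a * b for a, b in zip(q, k))) for k in K]
        z = sum(e)
        P.append([x / z for x in e])
    return P


rng = random.Random(1)
n, d = 32, 4
Q = [[rng.gauss(0, 1.5) for _ in range(d)] for _ in range(n)]
K = [[rng.gauss(0, 1.5) for _ in range(d)] for _ in range(n)]
P = softmax_rows(Q, K)
sparse, residual = split_sparse_residual(P, k=2)
print(stable_rank(P), stable_rank(residual))
```

Comparing the two printed values on peaked attention matrices mirrors the comparison in Figure 2, although a toy random example need not reproduce the paper's gap.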

Theorems & Definitions (21)

  • Theorem 2.1: Fast Gaussian KDE (Theorem 2 in Charikar et al., 2020)
  • Definition 3.1: Weighted Exponential KDE
  • Lemma 3.1: AMM
  • Theorem 3.2: Correctness of the outer-loop algorithm
  • Theorem 3.3: Analysis of the Weighted Exponential KDE algorithm
  • proof
  • Theorem 3.4: Approximate Attention with Spectral Norm Bound
  • Corollary 3.4: Simplified Runtime for Bounded Diameter Datasets
  • Definition 7.1: Angular LSH
  • Claim 1
  • ...and 11 more
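The bucketing described in the Figure 5 caption can be sketched with random-hyperplane (SimHash-style) angular hashing: $r$ random Gaussian directions give each point an $r$-bit code, codes are ordered along a reflected Gray code so that consecutive codes differ in exactly one bit (unit Hamming distance), and the sorted points are cut into equal-sized blocks. This is an illustration consistent with the caption under those assumptions; the precise construction in Definition 7.1 may differ, and all function names are hypothetical.

```python
import random


def gray_rank(code):
    """Position of `code` in the reflected Gray-code sequence (inverse Gray transform)."""
    rank = 0
    while code:
        rank ^= code
        code >>= 1
    return rank


def angular_lsh_codes(X, r, rng):
    """r-bit code per point: bit b is set when the point lies on the positive side of hyperplane b."""
    planes = [[rng.gauss(0, 1) for _ in range(len(X[0]))] for _ in range(r)]
    codes = []
    for x in X:
        code = 0
        for b, w in enumerate(planes):
            if sum(wi * xi for wi, xi in zip(w, x)) >= 0:
                code |= 1 << b
        codes.append(code)
    return codes


def equal_size_buckets(X, r, block, rng):
    """Sort points by the Gray-code rank of their hash, then cut into blocks of `block` points.

    Consecutive codes in this order differ in a single bit, so truncating uneven buckets
    tends to spill points into a bucket whose region is adjacent in space.
    """
    codes = angular_lsh_codes(X, r, rng)
    order = sorted(range(len(X)), key=lambda i: gray_rank(codes[i]))
    return [order[s:s + block] for s in range(0, len(order), block)]


rng = random.Random(7)
points = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(20)]
print(equal_size_buckets(points, r=2, block=5, rng=rng))
```

With $r = 2$ in $d = 2$, the two hyperplanes are lines through the origin, and walking around the circle flips one sign at a time, which is exactly the Gray-code order used for sorting.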