Data-Aware Random Feature Kernel for Transformers

Amirhossein Farzam, Hossein Mobahi, Nolan Andrew Miller, Luke Sernau

TL;DR

DARKFormer is a Data-Aware Random-feature Kernel transformer with a data-aligned kernel geometry. It learns the random-projection covariance, efficiently realizing an importance-sampled positive random-feature estimator for its data-aligned kernel.

Abstract

Transformers excel across domains, yet their quadratic attention complexity poses a barrier to scaling. Random-feature attention, as in Performers, can reduce this cost to linear in the sequence length by approximating the softmax kernel with positive random features drawn from an isotropic distribution. In pretrained models, however, queries and keys are typically anisotropic. This induces high Monte Carlo variance in isotropic sampling schemes unless one retrains the model or uses a large feature budget. Importance sampling can address this by adapting the sampling distribution to the input geometry, but complex data-dependent proposal distributions are often intractable. We show that data-aligning the softmax kernel yields an attention mechanism that both admits a tractable minimal-variance proposal distribution for importance sampling and exhibits better training stability. Motivated by this finding, we introduce DARKFormer, a Data-Aware Random-feature Kernel transformer that features a data-aligned kernel geometry. DARKFormer learns the random-projection covariance, efficiently realizing an importance-sampled positive random-feature estimator for its data-aligned kernel. Empirically, DARKFormer narrows the performance gap with exact softmax attention, particularly in finetuning regimes where pretrained representations are anisotropic. By combining random-feature efficiency with data-aware kernels, DARKFormer advances kernel-based attention in resource-constrained settings.
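
The linear-cost claim for random-feature attention can be illustrated with a minimal NumPy sketch. It assumes the standard Performer-style setup rather than DARKFormer itself: isotropic projections $\omega \sim \mathcal{N}(0, I_d)$, the unscaled softmax kernel $\exp(\mathbf{q}^\top \mathbf{k})$, and non-causal attention. All function names, shapes, and the test data are illustrative and do not come from the paper.

```python
import numpy as np

def positive_random_features(x, omega):
    """Map rows of x (n, d) to positive features (n, m) using projections omega (m, d)."""
    m = omega.shape[0]
    half_sq_norm = 0.5 * np.sum(x * x, axis=-1, keepdims=True)
    return np.exp(x @ omega.T - half_sq_norm) / np.sqrt(m)

def exact_softmax_attention(q, k, v):
    """Reference attention: builds the full (L, L) score matrix, quadratic in L."""
    scores = np.exp(q @ k.T)
    return (scores @ v) / scores.sum(axis=-1, keepdims=True)

def random_feature_attention(q, k, v, omega):
    """Approximate attention without forming the (L, L) matrix: cost O(L * m * d)."""
    q_feat = positive_random_features(q, omega)   # (L, m)
    k_feat = positive_random_features(k, omega)   # (L, m)
    kv = k_feat.T @ v                             # (m, d_v), summed over keys once
    normalizer = q_feat @ k_feat.sum(axis=0)      # (L,), approximates the row sums
    return (q_feat @ kv) / normalizer[:, None]

# Illustrative data (assumption): small-norm queries and keys keep the variance low.
rng = np.random.default_rng(0)
seq_len, d, m = 64, 8, 4096
q = 0.3 * rng.standard_normal((seq_len, d))
k = 0.3 * rng.standard_normal((seq_len, d))
v = rng.standard_normal((seq_len, d))
omega = rng.standard_normal((m, d))               # isotropic projections

approx = random_feature_attention(q, k, v, omega)
exact = exact_softmax_attention(q, k, v)
max_abs_error = float(np.max(np.abs(approx - exact)))
```

The key step is reordering the product: `k_feat.T @ v` is computed once and reused for every query, so the cost grows with the sequence length times the number of random features instead of with the squared sequence length. With anisotropic queries and keys, the error of this isotropic estimator grows, which is the regime the abstract targets.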

Paper Structure

This paper contains 31 sections, 5 theorems, 32 equations, and 5 figures.

Key Result

Lemma 2.1

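For $\mathbf{x}, \mathbf{y} \in \mathbb{R}^d$,
$$\exp(\mathbf{x}^\top \mathbf{y}) = \mathbb{E}_{\boldsymbol{\omega} \sim \mathcal{N}(0, \mathbf{I}_d)}\left[\exp\left(\boldsymbol{\omega}^\top \mathbf{x} - \tfrac{\|\mathbf{x}\|^2}{2}\right) \exp\left(\boldsymbol{\omega}^\top \mathbf{y} - \tfrac{\|\mathbf{y}\|^2}{2}\right)\right].$$
Hence, positive random features (PRFs) yield an unbiased approximation of the softmax kernel.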
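
The unbiasedness claim can be checked numerically. The sketch below assumes the standard positive-random-feature form from Performers: projections $\boldsymbol{\omega} \sim \mathcal{N}(0, \mathbf{I}_d)$, feature map $\exp(\boldsymbol{\omega}^\top \mathbf{x} - \|\mathbf{x}\|^2/2)$, and the unscaled softmax kernel $\exp(\mathbf{x}^\top \mathbf{y})$. The name `prf_estimate` and the test vectors are illustrative and not taken from the paper.

```python
import numpy as np

def prf_estimate(x, y, omega):
    """Monte Carlo estimate of exp(x . y) from positive random features.

    omega has shape (m, d); each row is one Gaussian projection.
    """
    feat_x = np.exp(omega @ x - 0.5 * (x @ x))   # (m,), strictly positive
    feat_y = np.exp(omega @ y - 0.5 * (y @ y))   # (m,), strictly positive
    return float(np.mean(feat_x * feat_y))

# Illustrative inputs (assumption): small-norm vectors, many samples.
rng = np.random.default_rng(0)
d, m = 4, 200_000
x = np.array([0.3, -0.2, 0.1, 0.4])
y = np.array([0.2, 0.1, -0.3, 0.25])
omega = rng.standard_normal((m, d))

estimate = prf_estimate(x, y, omega)
target = float(np.exp(x @ y))                    # exact softmax kernel value
abs_error = abs(estimate - target)
```

Under these assumptions, the estimator's relative variance for a single sample is $\exp(\|\mathbf{x} + \mathbf{y}\|^2) - 1$, so the estimate degrades quickly as the inputs grow in norm along any direction. This is why anisotropic queries and keys are costly for isotropic sampling.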

Figures (5)

  • Figure 1: The random feature attention replaces the softmax kernel with a linear approximation in the feature space, reducing the quadratic complexity in sequence length ($L$) to linear in sequence length times sample size ($m$).
  • Figure 2: Next-token prediction accuracy during pretraining (top) and finetuning (bottom) of the Gemma-2B model with DARKFormer (green), Performer (orange), a learned feature kernel (LFK, blue), a random baseline (yellow), a constant baseline (lime), and exact softmax attention. DARKFormer considerably narrows the gap between the Performer-type model and exact softmax and also outperforms LFK, especially in finetuning.
  • Figure 3: Next-token prediction accuracy for finetuning of the Gemma-2B model with DARKFormer (green), Performer (orange), and exact softmax (blue) attention over a long cycle of 650k finetuning steps. DARKFormer outperforms Performer throughout training despite approximating a novel whitened kernel that is out of distribution for pretrained Gemma. The x-axis is shown on a logarithmic scale.
  • Figure 4: Next-token prediction accuracy when finetuning only the q-k-v projection weights (and, for DARKFormer, the PRF projection covariance) in a Gemma-2B model with DARKFormer (green), Performer (orange), and exact softmax (blue) attention over 550k finetuning steps. The DARKFormer-induced improvement is even more pronounced than when finetuning the full model, and it does not decrease in later finetuning iterations.
  • Figure 5: Loss dynamics for finetuning of the Gemma-2B model with DARKFormer (green) and Performer (orange) at different learning rates. The shaded area marks the variance of loss values across seven learning rates, and the line shows the average. DARKFormer trains stably at all but the largest learning rate, where there is a short instability phase, while Performer shows more frequent instability phases and loss spikes at large learning rates. For visualization purposes, the horizontal axis is shown on a logarithmic scale.

Theorems & Definitions (5)

  • Lemma 2.1 (Choromanski et al., 2020)
  • Lemma 3.1
  • Theorem 3.2
  • Proposition 4.1
  • Proposition C.1