FLuRKA: Fast and accurate unified Low-Rank & Kernel Attention

Ahan Gupta, Hao Guo, Yueming Yuan, Yanqi Zhou, Charith Mendis

TL;DR

FLuRKA introduces a unified attention mechanism that fuses low-rank and kernel approximations to produce transformers whose step time is faster than that of either constituent method. The authors provide a theoretical speed analysis and an accuracy bound relative to full attention, and validate three variants that achieve speedups of up to 3.3x over low-rank methods and 1.7x over kernel methods while maintaining competitive accuracy across language modeling, language understanding, long-sequence modeling, machine translation, and image classification. Empirically, FLuRKA variants match or outperform their base components on six benchmarks, with training efficiency enabled by reduced FLOPs and an up-training strategy that blends base models with FLuRKA. The practical effect is faster, more scalable transformer models for text and vision tasks at lower computational cost.

Abstract

Many efficient $\textit{approximate}$ self-attention techniques have become prevalent since the inception of the transformer architecture. Two popular classes of these techniques are low-rank and kernel methods. Each of these methods has its strengths. We observe these strengths synergistically complement each other and exploit them to fuse low-rank and kernel methods, producing a new class of transformers: FLuRKA ($\textbf{F}$ast $\textbf{L}$ow-$\textbf{R}$ank & $\textbf{K}$ernel $\textbf{A}$ttention). FLuRKA are highly $\textit{training-efficient}$ with faster model speeds $\textit{and}$ similar model qualities compared to constituent low-rank and kernel methods. We theoretically and empirically evaluate the speed and quality of FLuRKA. Our model speed analysis posits a variety of parameter configurations where FLuRKA exhibit speedups over low-rank and kernel approximations and our model quality analysis bounds the error of FLuRKA with respect to full-attention. Empirically, we instantiate three FLuRKA variants which experience speedups of up to 3.3x and 1.7x over low-rank and kernel methods respectively. This translates to speedups of up to 20x over models with flash-attention. Across a diverse set of tasks spanning language modeling, language understanding, long sequence modeling, machine translation, and image classification, FLuRKA achieve comparable accuracy with underlying low-rank and kernel approximations, occasionally surpassing both.
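To make the idea of fusing the two approximations concrete, here is a minimal single-head sketch in plain Python. It rests on several assumptions that the abstract does not confirm: keys and values are first compressed along the sequence axis by down-projection matrices (named `e` and `f` here, after the $E_i$ and $F_i$ of Theorem 1), and a positive feature map is then applied to the queries and the compressed keys so that the softmax can be replaced by a linear-time product. The feature map `phi` (elu(x) + 1), the random Gaussian projections, and the helper names are illustrative stand-ins, not the authors' implementation.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*a)]


def phi(x):
    """Stand-in positive feature map, elu(x) + 1 (an assumption, not the paper's kernel)."""
    return [[v + 1.0 if v > 0 else math.exp(v) for v in row] for row in x]


def fused_attention(q, k, v, e, f):
    """Hypothetical fused low-rank + kernel attention for one head.

    q, k, v: n x d matrices; e, f: n x d_k down-projections. Returns n x d.
    """
    k_small = matmul(transpose(e), k)            # d_k x d: keys compressed along the sequence
    v_small = matmul(transpose(f), v)            # d_k x d: values compressed the same way
    q_feat = phi(q)                              # n x d kernelized queries
    k_feat = phi(k_small)                        # d_k x d kernelized compressed keys
    kv = matmul(transpose(k_feat), v_small)      # d x d summary whose size does not depend on n
    k_sum = [sum(col) for col in zip(*k_feat)]   # length-d vector used for normalization
    out = []
    for q_row, num_row in zip(q_feat, matmul(q_feat, kv)):
        z = sum(a * b for a, b in zip(q_row, k_sum))  # positive because phi > 0
        out.append([x / z for x in num_row])
    return out


# Tiny usage example with random inputs (shapes chosen arbitrarily).
random.seed(0)
n, d, d_k = 8, 4, 3
rand = lambda rows, cols: [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]
q, k, v = rand(n, d), rand(n, d), rand(n, d)
e, f = rand(n, d_k), rand(n, d_k)
out = fused_attention(q, k, v, e, f)
```

In this sketch the sequence length `n` only enters through the two down-projections and the final product with the kernelized queries. Each output row is a normalized weighted average of the `d_k` compressed value rows.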

Paper Structure

This paper contains 26 sections, 2 theorems, 48 equations, 5 figures, 6 tables.

Key Result

Theorem 1

Suppose we have a random feature map $\phi$ defined as follows: such that: Then for any $Q_i$, $K_i$, $V_i \in \mathbb{R}^{n\times d_m}$ and $W_i^Q$, $W_i^K$, $W_i^V \in \mathbb{R}^{d_m \times d_h}$, and $k = 5\log(d)/(\epsilon_2^2 - \epsilon_3^2)$. We have, for the matrices $E_i = \delta R$, $F_i = e^{-\delta}R$ where $R \in \mathbb{R}^{n \times k}$ whose entries are iid Occurs with probabilit

Figures (5)

  • Figure 1: We take a 12-layer, 12-head Performer pre-trained on WikiText-103 and plot the number of unique singular values in the SVD of the kernelized attention matrix, $\phi(QW^Q)\phi(KW^K)$, for every other attention head in each layer.
  • Figure 2: The pipeline of operations involved in constructing FLuRKA, parameterized by a kernel $\phi$, compared to low-rank and kernel methods. $N$, $d_m$, and $d_k$ are the sequence length, hidden dimension, and downsampling factor, respectively. $d_p$ is the dimension of the kernelized queries and keys.
  • Figure 3: Inference times of all variants as sequence length increases. All models have the same parameter count and were run for 100 iterations of inference. Our methods (FLuRKA) are in green, low-rank is in blue, and kernel is in orange. Lower is better.
  • Figure 4: The impact of the downsampling factor and hidden dimension on runtime performance (model speed), normalized to speedups over full-attention. The top figure compares FLuRKA to low-rank methods and the bottom figure compares FLuRKA to kernel methods. Our methods (FLuRKA) are shown as hashed bars, while low-rank and kernel methods are shown as solid bars.
  • Figure 5: Speedups FLuRKA attain over Flash-Attention.

Theorems & Definitions (6)

  • Claim 1
  • Theorem 1
  • Claim 2
  • Claim 3
  • Definition 1
  • Theorem: Linformer