Learning a Fourier Transform for Linear Relative Positional Encodings in Transformers

Krzysztof Marcin Choromanski, Shanda Li, Valerii Likhosherstov, Kumar Avinava Dubey, Shengjie Luo, Di He, Yiming Yang, Tamas Sarlos, Thomas Weingarten, Adrian Weller

TL;DR

This work tackles the Transformer attention bottleneck by introducing FourierLearner-Transformers (FLTs), which learn a spectral representation of relative positional encodings (RPEs) to flexibly integrate a wide range of RPEs without sacrificing linear attention efficiency. The authors establish a Performer-friendly framework that decomposes the RPE mask via random Fourier features, enabling low-rank, unbiased approximations and providing uniform convergence guarantees with a clear sample complexity bound. They explore Gaussian mixture RPEs, shift-invariant kernels, and local RPEs, and demonstrate FLTs on language modeling, image classification, molecular property prediction, and learnable optimizers, reporting strong accuracy gains and favorable computational costs. The approach broadens the applicability of RPE-enhanced linear attention, including challenging 3D geometric data, and offers practical, scalable Transformer variants for multiple modalities with minimal parameter overhead for the RPE component.

Abstract

We propose a new class of linear Transformers called FourierLearner-Transformers (FLTs), which incorporate a wide range of relative positional encoding mechanisms (RPEs). These include regular RPE techniques applied for sequential data, as well as novel RPEs operating on geometric data embedded in higher-dimensional Euclidean spaces. FLTs construct the optimal RPE mechanism implicitly by learning its spectral representation. As opposed to other architectures combining efficient low-rank linear attention with RPEs, FLTs remain practical in terms of their memory usage and do not require additional assumptions about the structure of the RPE mask. Besides, FLTs allow for applying certain structural inductive bias techniques to specify masking strategies, e.g. they provide a way to learn the so-called local RPEs introduced in this paper and give accuracy gains as compared with several other linear Transformers for language modeling. We also thoroughly test FLTs on other data modalities and tasks, such as image classification, 3D molecular modeling, and learnable optimizers. To the best of our knowledge, for 3D molecular data, FLTs are the first Transformer architectures providing linear attention and incorporating RPE masking.
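
To make the "low-rank linear attention with RPEs" idea concrete, here is a minimal pure-Python sketch of why a factorized mask keeps attention linear in the sequence length. It assumes the mask is already given as $\mathbf{N} \approx \mathbf{N}_1 \mathbf{N}_2^{\top}$ and that the query and key feature matrices come from some linear-attention feature map. The function names, the real-valued toy inputs, and the omission of the row normalization are illustrative assumptions, not the authors' implementation.

```python
def row_kron(a_rows, b_rows):
    """Row-wise Kronecker product: row i is the flattened outer product of a_i and b_i."""
    return [[x * y for x in a for y in b] for a, b in zip(a_rows, b_rows)]


def transpose(a):
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def masked_attention_quadratic(q, k, n1, n2, v):
    """Reference: materialize both L x L matrices, then apply the mask elementwise."""
    scores = matmul(q, transpose(k))   # L x L
    mask = matmul(n1, transpose(n2))   # L x L
    masked = [[s * t for s, t in zip(s_row, m_row)]
              for s_row, m_row in zip(scores, mask)]
    return matmul(masked, v)


def masked_attention_linear(q, k, n1, n2, v):
    """Same (unnormalized) output without forming any L x L matrix."""
    a = row_kron(n1, q)                 # L x (m*d)
    b = row_kron(n2, k)                 # L x (m*d)
    summary = matmul(transpose(b), v)   # (m*d) x d_v, independent of the query index
    return matmul(a, summary)


# Toy integer inputs (hypothetical), L = 3, d = 2, m = 2, d_v = 1.
q = [[1, 2], [0, 1], [3, 1]]
k = [[1, 0], [2, 1], [1, 1]]
n1 = [[1, 0], [1, 1], [0, 2]]
n2 = [[1, 1], [0, 1], [2, 0]]
v = [[1], [2], [3]]

out_linear = masked_attention_linear(q, k, n1, n2, v)
out_quadratic = masked_attention_quadratic(q, k, n1, n2, v)
```

The two functions agree because the elementwise product of two low-rank matrices is itself low-rank, with factors given by the row-wise Kronecker products. The linear variant costs on the order of $L \cdot m \cdot d \cdot d_v$ operations instead of $L^2$.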

Paper Structure

This paper contains 47 sections, 6 theorems, 36 equations, 6 figures, 5 tables.

Key Result

Theorem 4.1

Given $f:\mathbb{R}^{\ell} \rightarrow \mathbb{R}$ and $\mathbf{N} = [f(\mathbf{r}_{i}-\mathbf{r}_{j})] \in \mathbb{R}^{L \times L}$ as defined in Definition 3.1, denote by $g$ the Fourier Transform of $f$. Assume $p$ is some probability density function supported over $\mathbb{R}^{\ell}$. … Define $\mathbf{N}_{1}= \left[\varphi(\mathbf{r}_{1}),\cdots,\varphi(\mathbf{r}_{L})\right]^{\top}$ …
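
As a rough illustration of how a factorization of this kind can work, the sketch below builds an importance-weighted random Fourier feature estimate of $\mathbf{N}$ in pure Python. Everything specific here is an assumption made for the example rather than the paper's exact definition: positions are one-dimensional ($\ell = 1$), the Fourier convention is $g(\xi) = \int f(z) e^{-2\pi i \xi z}\,dz$, $g$ is nonnegative so that $\sqrt{g/p}$ can be split evenly between the two factors, and the names `rff_mask_factors` and `low_rank_product` are hypothetical.

```python
import cmath
import math
import random


def rff_mask_factors(positions, g, p_density, p_sampler, m, rng):
    """Return (N1, N2), each L x m, so that N1 @ N2^T estimates [f(r_i - r_j)]."""
    xis = [p_sampler(rng) for _ in range(m)]
    # Importance weight g/p, split as a square root between both factors.
    weights = [math.sqrt(g(xi) / p_density(xi) / m) for xi in xis]
    n1 = [[w * cmath.exp(2j * math.pi * xi * r) for xi, w in zip(xis, weights)]
          for r in positions]
    n2 = [[w * cmath.exp(-2j * math.pi * xi * r) for xi, w in zip(xis, weights)]
          for r in positions]
    return n1, n2


def low_rank_product(n1, n2):
    """Compute N1 @ N2^T (plain transpose, no complex conjugation)."""
    return [[sum(a * b for a, b in zip(row_i, row_j)) for row_j in n2]
            for row_i in n1]


# Example: f(z) = exp(-pi z^2), whose Fourier transform is g(xi) = exp(-pi xi^2).
f = lambda z: math.exp(-math.pi * z * z)
g = lambda xi: math.exp(-math.pi * xi * xi)
# Proposal density p: a standard normal, deliberately different from g.
p_density = lambda xi: math.exp(-xi * xi / 2.0) / math.sqrt(2.0 * math.pi)
p_sampler = lambda rng: rng.gauss(0.0, 1.0)

positions = [0.0, 0.3, 0.7, 1.2]
rng = random.Random(0)
n1, n2 = rff_mask_factors(positions, g, p_density, p_sampler, m=4000, rng=rng)
approx = low_rank_product(n1, n2)
max_err = max(
    abs(approx[i][j] - f(ri - rj))
    for i, ri in enumerate(positions)
    for j, rj in enumerate(positions)
)
```

Each entry of `approx` is a sample mean of $\frac{g(\xi)}{p(\xi)} e^{2\pi i \xi (r_i - r_j)}$ over $\xi \sim p$, whose expectation is $f(r_i - r_j)$ by Fourier inversion. The error shrinks as `m` grows, and only the two $L \times m$ factors need to be stored.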

Figures (6)

  • Figure 1: Model forward speed (left) and peak memory (right) comparisons between $\mathrm{FLT}$ and baselines under different input sequence lengths.
  • Figure 2: Validation loss of FLTs and the regular Performer on the IS2RE task of the OC20 dataset.
  • Figure 3: Comparisons of FLT with the regular Performer on the OC20 IS2RE task. The suffix "-$k$L" means the model consists of $k$ layers, e.g., $\mathrm{FLT}$-10L refers to a 10-layer $\mathrm{FLT}$. The evaluation metrics are Mean Absolute Error (MAE, lower is better) of the energies and the percentage of Energies within a Threshold (EwT, higher is better). The best performance is highlighted in bold.
  • Figure 3: Results of learnable optimizer experiments. Left: Adam and learnable optimizers using $\mathrm{FLT}$ and S4 on the task of training a ViT-Base classifier on ImageNet. Right: Adam and various learnable optimizers on the task of optimizing Rastrigin-type functions (from private conversation with the authors of jain2023mnemosyne).
  • Figure 4: Examples of the local RPE mechanisms discussed in Sec. \ref{sec:topology} and supported via $\mathrm{FLTs}$. Both examples are for tokens with positions described by two coordinates ($\ell=2$). The $x$ and $y$ coordinates encode the difference vector $\Delta \mathbf{r}=(\Delta r_1, \Delta r_2)^{\top}$. The $z$-coordinate provides the value of a function $f$. Left: (non-continuous) $f_{\mathbf{v},C}(\Delta\mathbf{r})=C \cdot \mathbb{I}_{\{|\Delta r_{1}| \leq v_{1}\}}\mathbb{I}_{\{|\Delta r_{2}| \leq v_{2}\}}$ for some $\mathbf{v}=(v_{1},v_{2})^{\top}$. Right: (continuous) $f_{\mathbf{v}}(\Delta \mathbf{r})=\mathbb{I}_{\{|\Delta r_{1}| \leq v_{1}\}}\mathbb{I}_{\{|\Delta r_{2}| \leq v_{2}\}}(-|\Delta r_{1}| + v_{1})(-|\Delta r_{2}| + v_{2})$. Both local RPE functions vanish outside the bounded region.
  • ...and 1 more figure
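
The two local RPE functions in the Figure 4 caption can be written down directly. The following is a small Python sketch that evaluates them for a difference vector of any dimension. The function names `local_rpe_box` and `local_rpe_tent` are hypothetical labels chosen for this example, and generalizing from $\ell = 2$ to an arbitrary number of coordinates is an assumption.

```python
def local_rpe_box(delta_r, v, c):
    """Non-continuous local RPE: C if |delta_r_k| <= v_k for every k, else 0."""
    inside = all(abs(d) <= vk for d, vk in zip(delta_r, v))
    return c if inside else 0.0


def local_rpe_tent(delta_r, v):
    """Continuous local RPE: product of (v_k - |delta_r_k|) inside the box, else 0."""
    if not all(abs(d) <= vk for d, vk in zip(delta_r, v)):
        return 0.0
    value = 1.0
    for d, vk in zip(delta_r, v):
        value *= vk - abs(d)
    return value


# Both functions vanish as soon as one coordinate leaves the box [-v_k, v_k].
v = (1.0, 2.0)
inside_box = local_rpe_box((0.5, -1.0), v, c=3.0)
outside_box = local_rpe_box((1.5, 0.0), v, c=3.0)
peak_tent = local_rpe_tent((0.0, 0.0), v)
edge_tent = local_rpe_tent((1.0, 0.0), v)
```

The box variant jumps from $C$ to 0 at the boundary, while the tent variant decays to 0 there, which is what makes it continuous.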

Theorems & Definitions (11)

  • Definition 3.1: General RPE for attention
  • Theorem 4.1
  • Theorem 4.2: Uniform convergence and sample complexity for approximation
  • Theorem A.1
  • proof
  • Lemma A.2
  • proof
  • Theorem A.3: Variance of RPE approximation
  • proof
  • Theorem A.4: Uniform convergence and sample complexity for approximation
  • ...and 1 more