Linear Transformer Topological Masking with Graph Random Features

Isaac Reid, Kumar Avinava Dubey, Deepali Jain, Will Whitney, Amr Ahmed, Joshua Ainslie, Alex Bewley, Mithun Jacob, Aranyak Mehta, David Rendleman, Connor Schenck, Richard E. Turner, René Wagner, Adrian Weller, Krzysztof Choromanski

TL;DR

This paper proposes to parameterise topological masks as a learnable function of a weighted adjacency matrix -- a novel, flexible approach which incorporates a strong structural inductive bias.

Abstract

When training transformers on graph-structured data, incorporating information about the underlying topology is crucial for good performance. Topological masking, a type of relative position encoding, achieves this by upweighting or downweighting attention depending on the relationship between the query and keys in a graph. In this paper, we propose to parameterise topological masks as a learnable function of a weighted adjacency matrix -- a novel, flexible approach which incorporates a strong structural inductive bias. By approximating this mask with graph random features (for which we prove the first known concentration bounds), we show how this can be made fully compatible with linear attention, preserving $\mathcal{O}(N)$ time and space complexity with respect to the number of input tokens. The fastest previous alternative was $\mathcal{O}(N \log N)$ and only suitable for specific graphs. Our efficient masking algorithms provide strong performance gains for tasks on image and point cloud data, including with $>30$k nodes.
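The abstract says the mask stays compatible with linear attention. That claim rests on a standard identity. Suppose $\mathbf{A}_{ij} = \phi(\mathbf{q}_i)^\top\phi(\mathbf{k}_j)$ and $\mathbf{M}_{ij} = \phi_\mathcal{G}(v_i)^\top\phi_\mathcal{G}(v_j)$. Then $(\mathbf{A}\odot\mathbf{M})_{ij} = (\phi(\mathbf{q}_i)\otimes\phi_\mathcal{G}(v_i))^\top(\phi(\mathbf{k}_j)\otimes\phi_\mathcal{G}(v_j))$, so the keys can be aggregated once without ever forming an $N\times N$ matrix.

Below is a minimal pure-Python sketch of that idea. It is not the authors' implementation. Its assumptions are:

- The query and key features (`phi_q`, `phi_k`) are precomputed and nonnegative.
- `g` is a small dense stand-in for the (normally sparse) GRF matrix.
- Attention is row-normalised.

All function names are hypothetical. The sketch checks the factorised route against an explicit quadratic computation.

```python
from typing import List

Matrix = List[List[float]]


def kron_row(a: List[float], b: List[float]) -> List[float]:
    """Kronecker product of two vectors, flattened: [a[p] * b[r] for all p, r]."""
    return [x * y for x in a for y in b]


def masked_linear_attention(phi_q: Matrix, phi_k: Matrix, g: Matrix, v: Matrix) -> Matrix:
    """Hypothetical helper: masked attention without forming the N x N scores.

    Scores are A_ij = phi_q[i] . phi_k[j] and the mask is M_ij = g[i] . g[j].
    Since (A * M)_ij = kron(phi_q[i], g[i]) . kron(phi_k[j], g[j]), keys are
    summed once and each query reads from that summary. Assumes nonnegative
    features and row normalisation. With sparse GRF rows stored sparsely the
    cost is linear in N; dense lists are used here only for brevity.
    """
    d = len(v[0])
    keys = [kron_row(pk, gk) for pk, gk in zip(phi_k, g)]
    dim = len(keys[0])
    s = [[0.0] * d for _ in range(dim)]  # sum_j key_j v_j^T
    z = [0.0] * dim                      # sum_j key_j (normaliser)
    for key, vj in zip(keys, v):
        for p in range(dim):
            z[p] += key[p]
            for c in range(d):
                s[p][c] += key[p] * vj[c]
    out = []
    for pq, gq in zip(phi_q, g):
        query = kron_row(pq, gq)
        denom = sum(qp * zp for qp, zp in zip(query, z))
        out.append([sum(query[p] * s[p][c] for p in range(dim)) / denom
                    for c in range(d)])
    return out


def naive_masked_attention(phi_q: Matrix, phi_k: Matrix, g: Matrix, v: Matrix) -> Matrix:
    """Quadratic reference: build each row of A * M explicitly, then normalise."""
    n, d = len(v), len(v[0])
    out = []
    for i in range(n):
        w = [sum(a * b for a, b in zip(phi_q[i], phi_k[j]))
             * sum(a * b for a, b in zip(g[i], g[j])) for j in range(n)]
        total = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(n)) / total
                    for c in range(d)])
    return out


# Toy inputs: 3 tokens/nodes. g[0] and g[2] share no nonzero coordinate,
# so the mask entry M_02 is zero and token 0 ignores value 2 entirely.
phi_q = [[1.0, 0.5], [0.2, 1.0], [0.7, 0.7]]
phi_k = [[0.9, 0.1], [0.3, 0.8], [0.5, 0.5]]
g = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

fast = masked_linear_attention(phi_q, phi_k, g, v)
slow = naive_masked_attention(phi_q, phi_k, g, v)
```

The two routes agree. Row 0 puts weights $0.95$ and $0.35$ on values 0 and 1 and zero weight on value 2, mirroring how sparse graph features zero out attention between nodes whose features never overlap.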

Paper Structure (22 sections, 3 theorems, 25 equations, 6 figures, 4 tables, 2 algorithms)

Key Result

Theorem 3.1

Consider a graph $\mathcal{G}$ with adjacency matrix $\mathbf{W}$ and node degrees $\{d_i\}_{v_i \in \mathcal{N}}$. Suppose we construct GRFs $\{\widehat{\phi}_\mathcal{G}(v_i)\}_{v_i \in \mathcal{N}}$ by sampling $n$ random walks $\{\omega_k^{(i)}\}_{v_i \in \mathcal{N},\, k \in [\![1,n]\!]}$
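The theorem's setting, GRFs built by sampling $n$ random walks per node, can be illustrated with a short importance-sampling estimator. The sketch below is an assumption-laden illustration, not the paper's algorithm. It assumes:

- a constant halting probability `p_halt`;
- neighbours sampled uniformly, which is where the degree $d_i$ enters the importance weight;
- a finite modulation sequence `f`;
- every node having at least one neighbour.

Each walk from $v_i$ adds `f[k]` times the product of traversed edge weights, divided by the probability of sampling that walk prefix, to the entry of the node it occupies after $k$ steps. This gives an unbiased, sparse estimate of row $i$ of $\sum_k f_k\mathbf{W}^k$. The function names are hypothetical.

```python
import random
from typing import Dict, List, Tuple

# Weighted graph as an adjacency list: node -> [(neighbour, edge weight), ...].
# Assumption: every node has at least one neighbour.
Graph = Dict[int, List[Tuple[int, float]]]


def grf_row(graph: Graph, start: int, f: List[float], n_walks: int,
            p_halt: float, rng: random.Random) -> Dict[int, float]:
    """Hypothetical helper: sparse, unbiased estimate of row `start` of
    sum_k f[k] * W^k, built from `n_walks` halting random walks.

    Assumptions: constant halting probability `p_halt`, neighbours chosen
    uniformly at random, and a finite modulation sequence `f`.
    """
    phi: Dict[int, float] = {}
    for _ in range(n_walks):
        node, load = start, 1.0
        phi[node] = phi.get(node, 0.0) + f[0] * load / n_walks
        for length in range(1, len(f)):
            if rng.random() < p_halt:
                break  # the walk halts here
            neighbours = graph[node]
            node_next, weight = rng.choice(neighbours)
            # Importance weight: edge weight divided by the probability
            # (1 - p_halt) / degree of having taken this particular step.
            load *= weight * len(neighbours) / (1.0 - p_halt)
            node = node_next
            phi[node] = phi.get(node, 0.0) + f[length] * load / n_walks
    return phi


def exact_row(graph: Graph, start: int, f: List[float]) -> Dict[int, float]:
    """Exact row `start` of sum_k f[k] * W^k, for checking the estimator."""
    power_row = {start: 1.0}  # row `start` of W^0
    total = {start: f[0]}
    for k in range(1, len(f)):
        next_row: Dict[int, float] = {}
        for u, val in power_row.items():
            for v, w in graph[u]:
                next_row[v] = next_row.get(v, 0.0) + val * w
        power_row = next_row  # now row `start` of W^k
        for v, val in power_row.items():
            total[v] = total.get(v, 0.0) + f[k] * val
    return total


# Toy example: path graph 0 - 1 - 2 with edge weights 0.5.
graph = {0: [(1, 0.5)], 1: [(0, 0.5), (2, 0.5)], 2: [(1, 0.5)]}
f = [1.0, 0.5, 0.25]
phi = grf_row(graph, 0, f, n_walks=20000, p_halt=0.5, rng=random.Random(0))
exact = exact_row(graph, 0, f)
```

Each walk visits at most `len(f) - 1` nodes beyond its start, so the feature has at most `1 + n_walks * (len(f) - 1)` nonzero entries however large the graph is. This matches the spirit of the sparsity result (Lemma 3.2).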

Figures (6)

  • Figure 1: Schematic overview. Regular attention is $\mathcal{O}(N^2)$, with $N$ the number of input tokens. Topological masking modulates the attention matrix $\mathbf{A}$ by a graph function $\mathbf{M}(\mathcal{G})$, improving predictive performance. Linear attention reduces the time complexity to $\mathcal{O}(N)$ by leveraging a low-rank decomposition. Our contribution (blue) is the first algorithm to achieve both -- $\mathcal{O}(N)$ topological masking of low-rank attention -- by approximating $\mathbf{M}(\mathcal{G})$ with graph random features (GRFs). GRFs are sparse vectors (denoted 'sp.') computed by sampling random walks. They are constructed so that $\mathbb{E}(\widehat{\mathbf{\Phi}}_{\mathbf{Q}, \mathcal{G}}\widehat{\mathbf{\Phi}}_{\mathbf{K}, \mathcal{G}}^\top) = \mathbf{A} \odot \mathbf{M}(\mathcal{G})$, with strong concentration properties (Thm. 3.1).
  • Figure 2: Visual overview. A graph $\mathcal{G}$ (left) has a weighted adjacency matrix $\mathbf{W}$ (centre left). A learnable power series $\mathbf{M}_\alpha(\mathcal{G})\coloneqq \sum_{i=0}^\infty \alpha_i \mathbf{W}^i$ is an effective topological mask, or graph relative position encoding (RPE) (centre right). $\mathbf{M}_\alpha(\mathcal{G})$ can be efficiently approximated using graph random features (centre top), which perform importance sampling of halting random walks. The feature $\widehat{\phi}_\mathcal{G}(v_i)$ is only nonzero at entries visited by the ensemble of walks beginning at $v_i$. In Thm. 3.1, we prove that the number of such entries is $\mathcal{O}(1)$ whilst still accurately estimating $\mathbf{M}_\alpha(\mathcal{G})_{ij}$ with high probability, so GRFs are sparse. This unlocks $\mathcal{O}(N)$ topological masking. Note that $\widehat{\phi}_\mathcal{G}(v_i)^\top \widehat{\phi}_\mathcal{G}(v_j)$ is only nonzero if the features are nonzero at some of the same coordinates, which happens if their respective ensembles of walks 'hit'. This incorporates a strong structural inductive bias. The equations on the right formalise our method mathematically.
  • Figure 3: Number of FLOPs vs. number of graph nodes for softmax attention, linear attention, and linear attention with GRF topological masking (ours).
  • Figure 4: Rendered rollouts. NeRF renderings of the predicted dynamics of a bimanual Kuka robot, conditioned on the initial scene and the sequence of robot actions. We show four tasks: using a dustpan and brush, lifting a can, moving a green block, and dropping a can. GRF Interlacers model point cloud dynamics more accurately, so the predicted frame renderings are closer to the ground truth.
  • Figure 5: Accuracy comparison. Structural similarity index measure (SSIM) between ground truth camera frames and predictive NeRF renderings after $100$k training steps, plotted against rollout timestep. Higher is better. GRFs improve the accuracy of dynamics prediction compared to the baselines.
  • ...and 1 more figure
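A short derivation makes explicit which power series the GRF inner products target. It rests on three assumptions not stated in these captions:

- Each GRF is an unbiased estimate of row $i$ of $\sum_k f_k \mathbf{W}^k$ for some modulation sequence $f$. The symbol $f$ is introduced here for illustration.
- The walk ensembles for $v_i$ and $v_j$ are sampled independently.
- $\mathbf{W}$ is symmetric.

```latex
\mathbb{E}\big[\widehat{\phi}_\mathcal{G}(v_i)\big] = \Big[\sum_{a=0}^\infty f_a \mathbf{W}^a\Big]_{i,:}
\;\Longrightarrow\;
\mathbb{E}\big[\widehat{\phi}_\mathcal{G}(v_i)^\top \widehat{\phi}_\mathcal{G}(v_j)\big]
= \sum_{a,b} f_a f_b \big[\mathbf{W}^a (\mathbf{W}^b)^\top\big]_{ij}
= \sum_{l=0}^\infty \Big(\sum_{a=0}^{l} f_a f_{l-a}\Big) [\mathbf{W}^l]_{ij}
= \mathbf{M}_\alpha(\mathcal{G})_{ij},
\qquad \alpha_l = \sum_{a=0}^{l} f_a f_{l-a}.
```

The first equality after the implication uses independence, so the expectation of the product factorises. The second uses $(\mathbf{W}^b)^\top = \mathbf{W}^b$. Under these assumptions, learning $f$ implicitly parameterises the mask coefficients $\alpha$ through a discrete self-convolution.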

Theorems & Definitions (4)

  • Remark 3.1: Explicit $N$-dimensional features for graph node kernels [reid2023universal]
  • Theorem 3.1: GRF exponential concentration bounds
  • Lemma 3.2: GRF sparsity
  • Corollary 3.3: GRFs implement $\mathcal{O}(N)$ topological masking