
FLARE: Fast Low-rank Attention Routing Engine

Vedant Puri, Aditya Joglekar, Kevin Ferguson, Yu-hsuan Chen, Yongjie Jessica Zhang, Levent Burak Kara

TL;DR

FLARE addresses the quadratic cost of self-attention on large unstructured meshes by routing attention through a fixed-length latent sequence per head, achieving linear $O(NM)$ time and enabling scalable PDE surrogate modeling. It introduces encoding and decoding cross-attentions with fixed latent queries, yielding a low-rank communication pattern; expressivity is preserved through head-wise independent projections and deep residual MLPs for the key/value projections. Spectral analysis confirms a low-rank, diverse, head-specific attention structure, and experiments show FLARE achieving state-of-the-art accuracy across diverse PDE benchmarks while scaling to one million points on a single GPU. The work also releases a large laser powder bed fusion (LPBF) additive manufacturing dataset to spur further research and provides open-source code that integrates with standard fused attention kernels.

Abstract

The quadratic complexity of self-attention limits its applicability and scalability on large unstructured meshes. We introduce the Fast Low-rank Attention Routing Engine (FLARE), a linear-complexity self-attention mechanism that routes attention through fixed-length latent sequences. Each attention head performs global communication among $N$ tokens by projecting the input sequence onto a fixed-length latent sequence of $M \ll N$ tokens using learnable query tokens. By routing attention through a bottleneck sequence, FLARE learns a low-rank form of attention that can be applied at $O(NM)$ cost. FLARE not only scales to unprecedented problem sizes but also delivers superior accuracy compared with state-of-the-art neural PDE surrogates across diverse benchmarks. We also release a new additive manufacturing dataset to spur further research. Our code is available at https://github.com/vpuri3/FLARE.py.
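A minimal NumPy sketch of the single-head routing described above, assuming the softmax forms given in the Figure 1 caption (no $1/\sqrt{D}$ scaling) and random stand-ins for the learned latent queries and key/value projections. The function name `flare_mix_single_head` is hypothetical and not part of the FLARE.py API. The sketch checks that the two-step encode/decode computation equals token mixing with the dense matrix $W_\text{decode} \cdot W_\text{encode}$, whose rank is at most $M$.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def flare_mix_single_head(Q, K, V):
    """Hypothetical single-head sketch: route attention through M latent tokens.

    Q: (M, D) latent queries (random stand-ins for learned parameters)
    K: (N, D) keys and V: (N, D) values computed from the N input tokens.
    Returns Y: (N, D). Cost is O(N * M * D); no N x N matrix is formed.
    """
    W_encode = softmax(Q @ K.T, axis=-1)  # (M, N): each latent attends over N inputs
    Z = W_encode @ V                      # (M, D): latent (bottleneck) sequence
    W_decode = softmax(K @ Q.T, axis=-1)  # (N, M): each input attends over M latents
    return W_decode @ Z                   # (N, D)

rng = np.random.default_rng(0)
N, M, D = 512, 16, 8
Q = rng.standard_normal((M, D))
K = rng.standard_normal((N, D))
V = rng.standard_normal((N, D))

Y = flare_mix_single_head(Q, K, V)

# Reference only: build the dense N x N mixing matrix explicitly (O(N^2) memory).
W_dense = softmax(K @ Q.T, axis=-1) @ softmax(Q @ K.T, axis=-1)
print(np.allclose(Y, W_dense @ V))            # True
print(np.linalg.matrix_rank(W_dense) <= M)    # True: rank is bounded by M
print(np.allclose(W_dense.sum(axis=1), 1.0))  # True: rows still sum to 1
```

Because `Z` has only $M$ rows, each of the two products costs $O(NMD)$; the dense $N \times N$ matrix appears here only to verify the result.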

Paper Structure

This paper contains 62 sections, 25 equations, 12 figures, 6 tables, 1 algorithm.

Figures (12)

  • Figure 1: Schematic of a FLARE block. In FLARE, each head projects the input sequence of $N$ tokens to a fixed-length sequence of $M$ tokens via the cross-attention matrix $W_\text{encode} = \mathrm{softmax}(Q \cdot K^T)$, and then projects back to the original length via the cross-attention matrix $W_\text{decode} = \mathrm{softmax}(K \cdot Q^T)$. The overall operation is equivalent to token mixing on the input sequence with the rank-deficient matrix $\left(W_\text{decode} \cdot W_\text{encode}\right)$.
  • Figure 2: Time and memory requirements of different attention schemes. On an input sequence of one million tokens, FLARE (red) is over $200\times$ faster than vanilla attention while consuming marginally more memory. All models are implemented with FlashAttention (Dao et al., 2022), and the dashed line marks the memory limit of a single H100 80GB GPU. The FLARE curves partially overlap.
  • Figure 3: PyTorch code for the multi-head token mixing operation in FLARE. A second pseudocode figure gives an implementation without the fused attention kernel.
  • Figure 4: We train FLARE on the DrivAerML dataset (Ashton et al., 2024) with one million points per geometry on a single NVIDIA H100 80GB GPU. We present (left) the test relative error, (middle) time per epoch (s), and (right) peak memory utilization (GB) as a function of the number of FLARE blocks ($B$) for different numbers of latent tokens ($M$).
  • Figure 5: Effect of the number of blocks ($B$) and the number of latent tokens ($M$) on test accuracy for the elasticity (left) and Darcy (right) test cases.
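Following the Figure 1 formulas, each head's token mixing can be written as two ordinary cross-attention calls: the latent queries attend to the input keys (encode), then the input keys act as queries over the latent tokens (decode). This is one way to see why standard fused attention kernels can be reused. The NumPy sketch below illustrates that structure only; it is not the authors' PyTorch code from Figure 3. `attention` is an unfused stand-in for a fused kernel, and the function names and the `(heads, tokens, dim)` layout are assumptions.

```python
import numpy as np

def attention(q, k, v, scale=1.0):
    """Unfused attention softmax(q k^T * scale) v, batched over heads.

    Stand-in for a fused kernel. q: (H, Lq, D); k, v: (H, Lk, D).
    scale=1.0 follows the Figure 1 formulas (an assumption).
    """
    s = np.einsum("hqd,hkd->hqk", q, k) * scale
    s = s - s.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(s)
    p = p / p.sum(axis=-1, keepdims=True)  # softmax over the key axis
    return np.einsum("hqk,hkd->hqd", p, v)

def flare_token_mixing(q_latent, k, v):
    """Hypothetical multi-head mixing as two cross-attention calls.

    q_latent: (H, M, D) latent queries, an independent set per head.
    k, v: (H, N, D) keys/values for the N input tokens.
    """
    z = attention(q_latent, k, v)     # encode: (H, M, D), M latents attend to N tokens
    return attention(k, q_latent, z)  # decode: (H, N, D), N tokens attend to M latents

rng = np.random.default_rng(0)
H, N, M, D = 4, 1000, 32, 16
q_latent = rng.standard_normal((H, M, D))
k = rng.standard_normal((H, N, D))
v = rng.standard_normal((H, N, D))

y = flare_token_mixing(q_latent, k, v)
print(y.shape)  # (4, 1000, 16)
```

In PyTorch, each `attention` call would correspond to `torch.nn.functional.scaled_dot_product_attention`, which applies a $1/\sqrt{D}$ scale unless its `scale` argument is set explicitly.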