Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level

Ali Hassani, Wen-Mei Hwu, Humphrey Shi

TL;DR

Fused neighborhood attention is developed: an adaptation of fused dot-product attention kernels that allows fine-grained control over attention across different spatial axes. With it, neighborhood attention can now enjoy a reduced and constant memory footprint, and record-breaking half-precision runtime.

Abstract

Neighborhood attention reduces the cost of self attention by restricting each token's attention span to its nearest neighbors. This restriction, parameterized by a window size and dilation factor, draws a spectrum of possible attention patterns between linear projection and self attention. Neighborhood attention, and more generally sliding window attention patterns, have long been bounded by infrastructure, particularly in higher-rank spaces (2-D and 3-D), calling for the development of custom kernels, which have been limited in either functionality, or performance, if not both. In this work, we aim to massively improve upon existing infrastructure by providing two new methods for implementing neighborhood attention. We first show that neighborhood attention can be represented as a batched GEMM problem, similar to standard attention, and implement it for 1-D and 2-D neighborhood attention. These kernels on average provide 895% and 272% improvement in full precision runtime compared to existing naive CUDA kernels for 1-D and 2-D neighborhood attention respectively. We find that aside from being heavily bound by memory bandwidth, certain inherent inefficiencies exist in all unfused implementations of neighborhood attention, which in most cases undo their theoretical efficiency gain. Motivated by the progress made into fused dot-product attention kernels, we developed fused neighborhood attention; an adaptation of fused dot-product attention kernels that allow fine-grained control over attention across different spatial axes. Known for reducing the quadratic time complexity of self attention to a linear complexity, neighborhood attention can now enjoy a reduced and constant memory footprint, and record-breaking half precision runtime. We observe that our fused implementation successfully circumvents some of the unavoidable inefficiencies in unfused implementations...
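As an illustration of the restriction described above, here is a minimal pure-Python sketch of 1-D neighborhood attention. It is a reference for the semantics only, not the authors' kernels. The function names are hypothetical, and the boundary handling (shifting the window inward so every query keeps exactly `window` neighbors) and the treatment of dilation (applying the same rule within each group of tokens that share `i % dilation`) are stated assumptions.

```python
import math

def neighborhood_1d(i, n, window, dilation=1):
    """Indices of the `window` keys that query `i` attends to (hypothetical helper)."""
    assert window * dilation <= n, "window must fit in the sequence"
    # Tokens sharing the same residue modulo `dilation` form one group.
    r, pos = i % dilation, i // dilation
    group_len = (n - r + dilation - 1) // dilation
    # Try to center the query; shift the window inward at the boundaries.
    start = min(max(pos - window // 2, 0), group_len - window)
    return [(start + j) * dilation + r for j in range(window)]

def neighborhood_attention_1d(q, k, v, window, dilation=1):
    """Reference 1-D neighborhood attention over lists of float vectors."""
    n, dim = len(q), len(q[0])
    scale = 1.0 / math.sqrt(dim)
    out = []
    for i in range(n):
        idx = neighborhood_1d(i, n, window, dilation)
        # Only `window` dot products per query, instead of n.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in idx]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][c] for w, j in zip(weights, idx)) / total
            for c in range(len(v[0]))
        ])
    return out
```

With `window=1` each query attends only to itself, so the output equals `v` (the "linear projection" end of the spectrum). With `window` equal to the sequence length, every query sees every key, which is ordinary self attention. In between, the sketch computes `n * window` scores rather than `n * n`.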

Paper Structure (13 sections, 4 equations, 4 figures, 8 tables)

Figures (4)

  • Figure 1: Overview of average improvement in speed on A100 from our proposed implementation. Baseline is the set of naive CUDA kernels introduced in Neighborhood Attention Transformer (Hassani et al., 2023). GEMM-based NA improves 1-D problems by an average of 548% (forward pass) and 502% (forward + backward), and 2-D problems by an average of 193% (forward pass) and 92% (forward + backward). GEMM-based NA does not implement 3-D problems yet. Fused NA boosts performance further and improves 1-D problems by an average of 1759% (forward pass) and 844% (forward + backward), 2-D problems by an average of 958% (forward pass) and 385% (forward + backward), and 3-D problems by an average of 1135% (forward pass) and 447% (forward + backward).
  • Figure 2: Illustration of the spectrum of possible attention patterns provided by neighborhood attention. Neighborhood attention only attempts to center the query token (red) within the context window (blue), unlike sliding window attention (Ramachandran et al., 2019), which forces it. Neighborhood attention with window size 1 is equivalent to linear projection ("no attention"). Neighborhood attention approaches self attention as window size grows, and matches it when window size equals input size. Dilation introduces sparse global context, and causal masking prevents a query token from interacting with neighboring context tokens that have a larger coordinate along the corresponding mode. Window size, dilation, and causal masking can each be defined per mode/axis.
  • Figure 3: Illustration of our GEMM-based implementation of the 2-D pointwise-neighborhood (PN) operation. Input tensors $Q$ and $K$ are tiled according to their 2-D spatial layout. $Q$ is tiled with a static tile shape, $T_h \times T_w$. $K$ is tiled with a haloed shape of the $Q$ tile, $T^{\prime}_h \times T^{\prime}_w$, which is a function of the attention window size ($k_h \times k_w$) and the $Q$ tile coordinates. Once tiles are moved into local memory, they are viewed in matrix layout, and a $T_h T_w \times T^{\prime}_h T^{\prime}_w \times d$ shaped GEMM is computed ($d$ is the embedding dimension). Once done, the tile of dot products with shape $T_h T_w \times T^{\prime}_h T^{\prime}_w$ is scattered into valid attention weights of shape $T_h \times T_w \times k_h k_w$.
  • Figure 4: A simplified illustration of fused neighborhood attention. $Q$ and $KV$ tensors are tiled according to their spatial layout (1-D, 2-D, 3-D), with the latter haloed to include the entire neighborhood for all corresponding queries in the query tile. Resulting attention weights from the first GEMM are masked according to neighborhood attention parameters, before undergoing online softmax scaling, and going through the second GEMM with the corresponding value sub-tile.
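
The fused pipeline in the Figure 4 caption (first GEMM, neighborhood mask, online softmax, second GEMM) can be sketched for a single 1-D query row in plain Python. This is a simplified illustration under stated assumptions, not the paper's kernel: it handles one query instead of a query tile, has no dilation or causal masking, and skips key/value tiles that do not overlap the window as a stand-in for the haloed tile. All function names and the `tile` parameter are hypothetical.

```python
import math

def window_start(i, n, window):
    # Assumed boundary rule: center the query, shift the window inward at the edges.
    return min(max(i - window // 2, 0), n - window)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def fused_na_row(q_i, keys, values, i, window, tile=4):
    """Process key/value tiles one at a time with an online softmax."""
    n = len(keys)
    scale = 1.0 / math.sqrt(len(q_i))
    start = window_start(i, n, window)
    end = start + window
    running_max, running_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for t0 in range(0, n, tile):
        t1 = min(t0 + tile, n)
        if t1 <= start or t0 >= end:
            continue  # tile lies entirely outside the neighborhood
        # "First GEMM" plus neighborhood mask: out-of-window scores become -inf.
        scores = [scale * dot(q_i, keys[j]) if start <= j < end else -math.inf
                  for j in range(t0, t1)]
        # Online softmax: rescale earlier partial results to the new running max.
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(probs)
        # "Second GEMM": accumulate the weighted value sub-tile.
        acc = [a * correction + sum(p * values[t0 + t][c] for t, p in enumerate(probs))
               for c, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]

def reference_na_row(q_i, keys, values, i, window):
    """Unfused reference: materialize the whole row of attention weights first."""
    n = len(keys)
    scale = 1.0 / math.sqrt(len(q_i))
    start = window_start(i, n, window)
    idx = range(start, start + window)
    scores = [scale * dot(q_i, keys[j]) for j in idx]
    top = max(scores)
    probs = [math.exp(s - top) for s in scores]
    total = sum(probs)
    return [sum(p * values[j][c] for p, j in zip(probs, idx)) / total
            for c in range(len(values[0]))]
```

The fused row keeps only a running maximum, a running sum, and an output accumulator, so its working memory does not grow with the window size. The reference version stores every attention weight for the row before touching the values. Both should agree up to floating-point rounding for any tile size.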