Lean Attention: Hardware-Aware Scalable Attention Mechanism for the Decode-Phase of Transformers

Rya Sanovar, Srikant Bharadwaj, Renee St. Amant, Victor Rühle, Saravan Rajmohan

TL;DR

LeanAttention addresses decode-phase bottlenecks in long-context transformer inference by introducing a hardware-aware, exact-attention mechanism that reorganizes computation around LeanTiles and a stream-K style mapping. By treating softmax re-scaling as an associative reduction, it can split KV workloads into unequal blocks while preserving exact results, enabling balanced, high-occupancy execution on GPUs and scalability to multi-GPU tensor parallelism. Empirically, LeanAttention delivers substantial latency improvements over FlashAttention-2, FlashDecoding, and FlashInfer, with average speedups around 1.7–2.0x on decode-phase workloads and up to 8.3x at extreme context lengths, while maintaining accuracy. The work offers a practical path to efficient, scalable long-context generation in decoder-only transformers, reducing energy use and enabling more capable models in real-time applications.
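The stream-K style mapping can be illustrated with a small Python sketch. It assumes that each head's KV cache is already cut into a fixed number of equal LeanTiles. It also assumes that each CTA receives one contiguous, near-equal range of the flattened (head, tile) sequence. The names `partition_lean_tiles` and `Segment` are hypothetical, and this is not the paper's kernel. It only shows how giving every CTA equal work naturally produces unequal per-head blocks.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Segment:
    """A contiguous run of LeanTiles from a single head, owned by one CTA (hypothetical type)."""
    cta: int
    head: int
    tile_begin: int  # first LeanTile index within the head (inclusive)
    tile_end: int    # LeanTile index within the head where the run stops (exclusive)


def partition_lean_tiles(num_heads: int, tiles_per_head: int, num_ctas: int) -> List[Segment]:
    """Split num_heads * tiles_per_head LeanTiles into num_ctas near-equal contiguous ranges.

    A CTA's range may start or end in the middle of a head, so one head can be
    shared by several CTAs that each hold a different amount of its work.
    """
    total = num_heads * tiles_per_head
    base, extra = divmod(total, num_ctas)
    segments: List[Segment] = []
    start = 0
    for cta in range(num_ctas):
        # The first `extra` CTAs take one additional tile so every tile is covered.
        count = base + (1 if cta < extra else 0)
        end = start + count
        pos = start
        while pos < end:
            head = pos // tiles_per_head
            # Stop this segment at the end of the CTA's range or the end of the head.
            head_end = min(end, (head + 1) * tiles_per_head)
            offset = head * tiles_per_head
            segments.append(Segment(cta, head, pos - offset, head_end - offset))
            pos = head_end
        start = end
    return segments
```

Consider the five-SM, two-head, five-tiles-per-head example of Figure 1, i.e. `partition_lean_tiles(2, 5, 5)`. Each CTA gets two tiles, and the third CTA's range straddles both heads. It finishes head 0 (tile 4) and then starts head 1 (tile 0). The partial results for such shared heads are what the associative softmax reduction later combines.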

Abstract

Transformer-based models have emerged as one of the most widely used architectures for natural language processing, natural language generation, and image generation. The size of state-of-the-art models has increased steadily, reaching billions of parameters. These huge models are memory-hungry and incur significant inference latency even on cutting-edge AI accelerators such as GPUs. Specifically, the time and memory complexity of the attention operation is quadratic in the total context length, i.e., the number of prompt and output tokens. Thus, several optimizations such as key-value tensor caching and FlashAttention computation have been proposed to meet the low-latency demands of applications relying on such large models. However, these techniques do not cater to the computationally distinct nature of the different phases of inference. To that end, we propose LeanAttention, a scalable technique for computing self-attention in the token-generation phase (decode-phase) of decoder-only transformer models. LeanAttention enables scaling the attention mechanism implementation for the challenging case of long context lengths by re-designing the execution flow for the decode-phase. We identify that the associative property of online softmax can be treated as a reduction operation, thus allowing us to parallelize the attention computation over these large context lengths. We extend the "stream-K" style reduction of tiled calculation to self-attention to enable parallel computation, resulting in an average of 2.6x attention execution speedup over FlashAttention-2 and up to 8.33x speedup for 512k context lengths.
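The associative online-softmax reduction can be checked numerically with a minimal pure-Python sketch. It assumes a single query vector and standard scaled dot-product attention. Each KV chunk keeps a state of running maximum m, exponent sum l, and un-scaled output o. Two states are merged by re-scaling each with exp(m_part - m) before adding them. The function names are hypothetical, and each chunk is assumed to be non-empty. This illustrates the math, not the paper's GPU implementation.

```python
import math
from typing import List, Sequence, Tuple

# Partial state for one query over a chunk of keys:
#   m: max score in the chunk
#   l: sum of exp(score - m)
#   o: un-scaled output, sum of exp(score - m) * value
State = Tuple[float, float, List[float]]


def partial_state(q: Sequence[float], keys, values) -> State:
    """Compute the un-scaled softmax state for one non-empty KV chunk."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    o = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return m, l, o


def merge(a: State, b: State) -> State:
    """Combine two partial states; this operation is associative."""
    m_a, l_a, o_a = a
    m_b, l_b, o_b = b
    m = max(m_a, m_b)
    ca, cb = math.exp(m_a - m), math.exp(m_b - m)  # re-scaling factors
    return m, ca * l_a + cb * l_b, [ca * x + cb * y for x, y in zip(o_a, o_b)]


def finalize(state: State) -> List[float]:
    """Normalize the un-scaled output by the softmax denominator."""
    _, l, o = state
    return [x / l for x in o]


def split_attention(q, keys, values, boundaries: Sequence[int]) -> List[float]:
    """Attention over KV split at the given (possibly unequal) boundaries."""
    edges = [0] + list(boundaries) + [len(keys)]
    parts = [partial_state(q, keys[s:e], values[s:e]) for s, e in zip(edges, edges[1:])]
    state = parts[0]
    for p in parts[1:]:
        state = merge(state, p)
    return finalize(state)
```

Because `merge` is associative, splitting the KV sequence into chunks of sizes 2, 1, and 4 gives the same result as a single pass, up to floating-point rounding. Grouping the merges differently also gives the same result. This is the property that lets differently sized work volumes be reduced after they are computed independently.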

Paper Structure

This paper contains 19 sections, 5 equations, 13 figures, 1 table, and 2 algorithms.

Figures (13)

  • Figure 1: Execution schedule of FlashAttention-2 [dao2023flashattention2], FlashDecoding [flashdecoding] and FlashInfer [flashinfer] (fixed-split), and LeanAttention on a hypothetical five-SM GPU executing attention for 2 heads. LeanAttention splits the context into optimal LeanTiles (shown here with 5 tiles per head).
  • Figure 2: Timeshare of decode attention compared to other stages for different prompt sizes with an 8:1 token ratio for the Phi-3 Medium model at batch size 1.
  • Figure 3: Utilization of various resources on a single NVIDIA A100-80GB GPU for LeanAttention compared to the FlashDecoding kernel at Heads=56 and BS=1, measured using Nsight Compute. FlashDecoding suffers from a quantization-efficiency issue on the GPU's 108 SMs, whereas LeanAttention occupies all SMs available in the system.
  • Figure 4: Illustrative diagram showing LeanAttention's partitioning strategy, with two differently sized work volumes of a head assigned to different CTAs. The un-scaled outputs are computed independently and re-scaled later in a reduction operation. Note that this generalizes to any arbitrary split of the work volume.
  • Figure 5: Control and dataflow of a single CTA in LeanAttention utilizing various hardware resources. The tensors are loaded to shared memory in a tiled manner. At the end of a head, a host CTA performs the reduction; otherwise, the CTA writes its partial un-scaled results to memory before moving to the next head.
  • ...and 8 more figures
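We read the quantization-efficiency issue noted in Figure 3 as wave quantization (our interpretation). Under that reading, it can be estimated with simple arithmetic. The sketch below assumes equal-sized CTAs and one resident CTA per SM, so CTAs execute in waves of 108. The split count of 2 and the 64 tiles per head used below are hypothetical values, not figures from the paper. The function names are also hypothetical.

```python
import math


def fixed_split_efficiency(num_heads: int, splits_per_head: int, num_sms: int) -> float:
    """Useful fraction of SM time when every (head, split) pair is one equal-sized CTA
    and CTAs run in waves of num_sms (assumes one resident CTA per SM)."""
    num_ctas = num_heads * splits_per_head
    waves = math.ceil(num_ctas / num_sms)
    return num_ctas / (waves * num_sms)


def lean_efficiency(total_tiles: int, num_sms: int) -> float:
    """Useful fraction when total_tiles are spread over num_sms CTAs as evenly as
    possible (stream-K style); the busiest CTA holds ceil(total_tiles / num_sms) tiles."""
    busiest = math.ceil(total_tiles / num_sms)
    return total_tiles / (busiest * num_sms)
```

First take 56 heads with a hypothetical fixed split of 2. The resulting 112 CTAs need two waves on 108 SMs, so only about 52% of SM slots do useful work (`fixed_split_efficiency(56, 2, 108)`). Now spread a hypothetical 56 * 64 = 3584 LeanTiles over 108 CTAs instead. That gives about 98% (`lean_efficiency(3584, 108)`), because the imbalance is at most one tile per CTA. These numbers only illustrate the arithmetic; they are not measurements from the paper.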