FLASH-D: FlashAttention with Hidden Softmax Division

Kosmas Alexandridis, Vasileios Titopoulos, Giorgos Dimitrakopoulos

TL;DR

FLASH-D re-expresses the FlashAttention kernel so that the softmax division is hidden inside sigmoid-based weight computations, preserving numerical stability and the benefits of tiled, online softmax. The approach keeps the same attention semantics while removing the explicit division and running-sum tracking, which enables simpler, more area- and power-efficient hardware implementations. Hardware evaluations in a 28 nm ASIC technology show average area and power reductions of 22.8% and 20.3%, respectively, relative to FlashAttention2 hardware, without sacrificing accuracy or latency. The method may also enable computation skipping during inference, offering practical gains for large language models without introducing approximations. Overall, FLASH-D is a hardware-friendly reformulation that preserves the core attention properties while improving efficiency for long-sequence transformers.
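
A short derivation may make the idea of a "hidden" division concrete. It is a reconstruction for a single query, consistent with the Figure 2 caption (which plots $w_i$ against $s_i-s_{i-1}$ for several values of $w_{i-1}$); the symbols $\ell_i$, $o_i$ and $\sigma$ are assumed notation, and the paper's own formulation may differ in detail. Here $s_1,\dots,s_N$ are the attention scores, $v_1,\dots,v_N$ the value vectors, and $o_N$ the attention output.

```latex
\begin{aligned}
\ell_i &= \sum_{j=1}^{i} e^{s_j}, \qquad
o_i = \frac{1}{\ell_i}\sum_{j=1}^{i} e^{s_j} v_j, \qquad
w_i = \frac{e^{s_i}}{\ell_i} \\
o_i &= o_{i-1}\,\frac{\ell_{i-1}}{\ell_i} + \frac{e^{s_i}}{\ell_i}\,v_i
     = o_{i-1}\,(1 - w_i) + w_i\,v_i
     \qquad \left(\text{since } \tfrac{\ell_{i-1}}{\ell_i} = \tfrac{\ell_i - e^{s_i}}{\ell_i} = 1 - w_i\right) \\
w_i &= \frac{e^{s_i}}{\ell_{i-1} + e^{s_i}}
     = \frac{1}{1 + e^{-(s_i - \ln \ell_{i-1})}}
     = \sigma\!\left(s_i - \ln \ell_{i-1}\right) \\
\ln \ell_{i-1} &= s_{i-1} - \ln w_{i-1}
\;\;\Longrightarrow\;\;
w_i = \sigma\!\left(s_i - s_{i-1} + \ln w_{i-1}\right), \qquad w_1 = 1,\; o_1 = v_1
\end{aligned}
```

Under these assumptions the normalization by $\ell_i$ never appears as an explicit division: it is absorbed into the sigmoid, and neither a running sum nor a running maximum has to be stored.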

Abstract

The transformer's attention mechanism has revolutionized AI and machine learning, and computing it efficiently is crucial to its performance. However, calculating attention involves matrix operations interspersed with softmax rescaling, which inherently slows down computation and requires processing the entire input sequence. Building on online softmax computation, FlashAttention integrates the softmax calculation with matrix arithmetic, enabling tiled computation independent of sequence length. Although it was optimized for GPUs, FlashAttention's simplicity makes it amenable to direct hardware acceleration. This work re-evaluates the core FlashAttention kernel and presents FLASH-D, a mathematically equivalent yet simplified formulation that achieves: (a) hiding the softmax division within other non-linear function evaluations; (b) inherently numerically stable computation of exponentials, eliminating the need for maximum-value subtraction; and (c) a reduction in computational cost without introducing numerical approximations to the FlashAttention kernel. Importantly, the essential FlashAttention properties that facilitate efficient tiled implementation are fully preserved. Hardware implementation results at 28 nm demonstrate that the proposed formulation achieves, on average, a 22.8% reduction in area and a 20.3% reduction in power compared to state-of-the-art parallel hardware architectures, without any performance penalty.
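
To make the contrast in the abstract concrete, here is a minimal Python sketch for a single query vector, assuming the (already scaled) attention scores $s_i$ and value rows $v_i$ are consumed one key at a time. `online_softmax_attention` mirrors a FlashAttention2-style loop (running maximum, running sum, final division), while `flash_d_style_attention` applies the sigmoid-weight recurrence $w_i = \sigma(s_i - s_{i-1} + \ln w_{i-1})$, $o_i = (1-w_i)\,o_{i-1} + w_i v_i$, which was derived from the description and the Figure 2 caption rather than copied from the paper. All function names are hypothetical, tiling and low-precision formats are ignored, and carrying $\ln w_{i-1}$ in the log domain is an implementation assumption made so that tiny weights never reach $\ln 0$.

```python
import math
from typing import List, Sequence

Vector = List[float]


def reference_attention(scores: Sequence[float], values: Sequence[Sequence[float]]) -> Vector:
    """Direct softmax(scores) @ values with max subtraction; used as ground truth."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[d] for e, v in zip(exps, values)) / total for d in range(dim)]


def online_softmax_attention(scores: Sequence[float], values: Sequence[Sequence[float]]) -> Vector:
    """FlashAttention2-style loop: running max m, running sum l,
    unnormalized accumulator o, and one explicit division at the end."""
    dim = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * dim
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new)      # rescale old state to the new maximum (0.0 on first step)
        p = math.exp(s - m_new)
        l = l * scale + p                # running softmax denominator
        o = [oi * scale + p * vi for oi, vi in zip(o, v)]
        m = m_new
    return [oi / l for oi in o]          # explicit softmax division


def _sigmoid(x: float) -> float:
    # Logistic function; each branch only exponentiates a non-positive number.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def _log_sigmoid(x: float) -> float:
    # ln(sigmoid(x)) computed directly, so it never evaluates ln(0).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def flash_d_style_attention(scores: Sequence[float], values: Sequence[Sequence[float]]) -> Vector:
    """Hypothetical sigmoid-weight recurrence (requires at least one key):
    o_i = (1 - w_i) * o_{i-1} + w_i * v_i,  w_i = sigmoid(s_i - s_{i-1} + ln w_{i-1}).
    No running max, no running sum, no final division."""
    o = list(values[0])   # w_1 = 1, so o_1 = v_1
    log_w = 0.0           # ln w_1
    prev_s = scores[0]
    for s, v in zip(scores[1:], values[1:]):
        x = s - prev_s + log_w
        w = _sigmoid(x)
        o = [(1.0 - w) * oi + w * vi for oi, vi in zip(o, v)]
        log_w = _log_sigmoid(x)
        prev_s = s
    return o


if __name__ == "__main__":
    # Scores near 800 would overflow a naive exp(s) in float64.
    scores = [0.5, 3.0, -1.0, 2.5, 800.0, 799.0]
    values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, 0.7], [1.5, -0.5]]
    ref = reference_attention(scores, values)
    for fn in (online_softmax_attention, flash_d_style_attention):
        out = fn(scores, values)
        assert all(abs(a - b) < 1e-9 for a, b in zip(out, ref)), fn.__name__
    print("all variants agree:", [round(x, 6) for x in ref])
```

In this sketch both loops match the direct softmax to rounding error, but the second one only ever exponentiates non-positive arguments inside the sigmoid and never divides by an accumulated sum. It processes keys serially for clarity; it says nothing about how the paper's parallel architectures (Figures 1 and 3) schedule the work.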

Paper Structure

This paper contains 16 sections, 16 equations, 5 figures, 1 table, and 3 algorithms.

Figures (5)

  • Figure 1: A parallel hardware architecture for the FlashAttention2 kernel with multiple preloaded query vectors.
  • Figure 2: Weight $w_i$ as a function of the difference between consecutive attention scores, $s_i-s_{i-1}$. The four curves correspond to four different values of the previous weight $w_{i-1}$.
  • Figure 3: A parallel hardware architecture for the FLASH-D kernel with multiple preloaded query vectors.
  • Figure 4: Hardware area at 28 nm of the FLASH-D and FlashAttention2 kernels for computing the attention of a single query in the BFloat16 and FP8-E4M3 floating-point formats, across different hidden dimension lengths.
  • Figure 5: Average power of the FLASH-D and FlashAttention2 kernels for computing the attention of a single query in the BFloat16 and FP8-E4M3 floating-point formats, across different hidden dimension lengths. Memory and I/O power are not included, since they are identical for both designs.