FLASH-D: FlashAttention with Hidden Softmax Division
Kosmas Alexandridis, Vasileios Titopoulos, Giorgos Dimitrakopoulos
TL;DR
FLASH-D re-expresses the FlashAttention kernel so that the softmax division is hidden inside sigmoid-based weight computations, while preserving numerical stability and the benefits of tiled, online softmax. The approach keeps the same attention semantics but removes the explicit division and the running-sum tracking, which enables simpler, more area- and power-efficient hardware implementations. Hardware evaluations in 28 nm ASIC technology show substantial area and power reductions relative to FlashAttention2 without sacrificing accuracy or latency. The method also opens the possibility of skipping computation during inference, offering practical gains for large language models without introducing approximations. Overall, FLASH-D is a hardware-friendly reformulation that preserves the core attention properties while improving efficiency for long-sequence transformers.
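To make the sigmoid-weight idea concrete, here is a minimal Python sketch for a single query processed one key at a time (tile size 1, plain lists). It uses the recurrence w_i = sigmoid(s_i - s_{i-1} + ln w_{i-1}) and o_i = (1 - w_i) o_{i-1} + w_i v_i, which follows from the softmax definition; it is our own reconstruction, not the authors' kernel or hardware datapath, and the helper names (dot, reference_attention, log_sigmoid, sigmoid_weight_attention) are hypothetical. It is checked against conventional max-subtracted softmax attention.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def reference_attention(q: Vector, keys: Sequence[Vector],
                        values: Sequence[Vector]) -> List[float]:
    """Conventional softmax attention for one query: subtract the maximum,
    exponentiate, sum, then divide."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, values)) / total
            for j in range(len(values[0]))]


def log_sigmoid(x: float) -> float:
    """log(sigmoid(x)), evaluated so that it never overflows or takes log(0)."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sigmoid_weight_attention(q: Vector, keys: Sequence[Vector],
                             values: Sequence[Vector]) -> List[float]:
    """Hypothetical sketch with tile size 1. The carried state is the current
    output o, the previous score and ln(w_prev): there is no running maximum,
    no running sum and no final division."""
    if not keys:
        raise ValueError("need at least one key")
    o: List[float] = list(values[0])  # w_1 = 1, so o_1 = v_1
    prev_s = dot(q, keys[0])
    log_w = 0.0                        # ln(w_1) = ln(1) = 0
    for k, v in zip(keys[1:], values[1:]):
        s = dot(q, k)
        # w_i = sigmoid(s_i - s_{i-1} + ln w_{i-1}); the division lives inside sigmoid
        log_w = log_sigmoid(s - prev_s + log_w)
        w = math.exp(log_w)
        # o_i = (1 - w_i) * o_{i-1} + w_i * v_i
        o = [(1.0 - w) * oj + w * vj for oj, vj in zip(o, v)]
        prev_s = s
    return o
```

Scaling the query by 500 pushes the largest score to 1000, where a bare math.exp(s) would raise OverflowError; the sigmoid form still matches the max-subtracted reference, which illustrates the stability property without any maximum tracking.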
Abstract
The transformer's attention mechanism has revolutionized AI and machine learning, and computing it efficiently is crucial to performance. However, attention interleaves matrix operations with softmax rescaling, which inherently slows down computation and requires processing the entire input sequence. Building on online softmax computation, FlashAttention integrates the softmax calculation with the matrix arithmetic, enabling tiled computation independent of sequence length. Although FlashAttention was optimized for GPUs, its simplicity makes it amenable to direct hardware acceleration. This work re-evaluates the core FlashAttention kernel and presents FLASH-D, a mathematically equivalent yet simplified formulation that achieves: (a) hiding the softmax division within other non-linear function evaluations; (b) inherently numerically stable computation of exponentials, eliminating the need for maximum-value subtraction; and (c) a reduction in computational cost without introducing numerical approximations to the FlashAttention kernel. Importantly, the essential FlashAttention properties that facilitate efficient tiled implementation are fully preserved. Hardware implementation results at 28 nm show that the proposed formulation achieves, on average, a 22.8% reduction in area and a 20.3% reduction in power compared to state-of-the-art parallel hardware architectures, without any performance penalty.
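The reasoning behind points (a) and (b) can be reconstructed from the softmax definition alone. The derivation below is our own sketch for a single query with scores s_i and value vectors v_i processed one key at a time; the paper's notation and its tiled formulation may differ.

```latex
% Softmax attention output after the first i keys of one query
o_i = \frac{\sum_{j=1}^{i} e^{s_j} v_j}{\ell_i}, \qquad \ell_i = \sum_{j=1}^{i} e^{s_j}

% Weight of the newest key
w_i = \frac{e^{s_i}}{\ell_i}, \qquad w_1 = 1

% Output update, using \ell_{i-1}/\ell_i = 1 - w_i
o_i = \frac{\ell_{i-1}}{\ell_i}\, o_{i-1} + \frac{e^{s_i}}{\ell_i}\, v_i
    = (1 - w_i)\, o_{i-1} + w_i\, v_i

% Weight update, using \ell_{i-1} = e^{s_{i-1}} / w_{i-1}
w_i = \frac{e^{s_i}}{e^{s_i} + e^{s_{i-1}} / w_{i-1}}
    = \frac{1}{1 + e^{-(s_i - s_{i-1} + \ln w_{i-1})}}
    = \sigma\!\left(s_i - s_{i-1} + \ln w_{i-1}\right)
```

In this form the division by \ell_i is absorbed into the sigmoid, neither a running maximum nor a running sum is stored, and because \sigma maps any argument into (0, 1) the weights cannot overflow.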
