Flash Window Attention: speedup the attention computation for Swin Transformer
Zhendong Zhang
TL;DR
This work tackles the high computational cost of attention in high-resolution vision models by combining Swin Transformer window attention with a flash-style scheme tailored to it. It introduces Flash Window Attention, which tiles attention along the feature dimension and keeps the entire short attention matrix on-chip, enabling efficient forward and backward passes. The approach achieves up to a 300% speedup in attention computation and up to a 30% end-to-end speedup, demonstrated with a Triton/PyTorch implementation on an RTX 4090 GPU. While the method is effective for typical window sizes, the report acknowledges its limitations for very large windows and points to future work on broader window patterns.
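The core idea of tiling along the feature dimension while keeping the whole attention matrix can be illustrated with a pure-Python sketch for a single window and a single head. This is not the author's Triton kernel: the function names, the `tile_d` parameter, and the loop order are illustrative assumptions. Because a window's sequence is short (49 tokens for an assumed 7 x 7 window), the full score matrix can be held in one piece, so Q K^T is accumulated over feature tiles and each softmax row is computed exactly. This avoids the running-max rescaling that flash attention uses when it tiles along the sequence dimension.

```python
import math
import random


def naive_attention(q, k, v):
    """Reference single-head attention for one window: softmax(Q K^T / sqrt(d)) V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        out.append([sum(e / z * vj[c] for e, vj in zip(exps, v)) for c in range(len(v[0]))])
    return out


def feature_tiled_attention(q, k, v, tile_d):
    """Sketch (hypothetical helper): tile over features, keep the whole score matrix."""
    n_tok, d, dv = len(q), len(q[0]), len(v[0])
    scale = 1.0 / math.sqrt(d)
    # Full n_tok x n_tok score matrix. For a short window sequence (e.g. 49
    # tokens) a GPU kernel could keep a matrix of this size in on-chip memory.
    s = [[0.0] * n_tok for _ in range(n_tok)]
    # Pass 1: accumulate Q K^T one feature tile [c0, c1) at a time.
    for c0 in range(0, d, tile_d):
        c1 = min(c0 + tile_d, d)
        for i in range(n_tok):
            for j in range(n_tok):
                s[i][j] += sum(q[i][c] * k[j][c] for c in range(c0, c1))
    # Exact row-wise softmax: every row is already complete, so no running-max
    # rescaling (as in sequence-tiled flash attention) is needed.
    p = []
    for row in s:
        m = max(row)
        exps = [math.exp(scale * (x - m)) for x in row]
        z = sum(exps)
        p.append([e / z for e in exps])
    # Pass 2: produce the output P V one tile of value features at a time.
    out = [[0.0] * dv for _ in range(n_tok)]
    for c0 in range(0, dv, tile_d):
        for i in range(n_tok):
            for c in range(c0, min(c0 + tile_d, dv)):
                out[i][c] = sum(p[i][j] * v[j][c] for j in range(n_tok))
    return out


random.seed(0)
n_tok, d = 49, 32  # assumed: one 7 x 7 window, 32 features per head
q, k, v = ([[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n_tok)] for _ in range(3))
ref = naive_attention(q, k, v)
tiled = feature_tiled_attention(q, k, v, tile_d=8)
# The two results should differ only by floating-point rounding.
print(max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb)))
```

In a real kernel the payoff comes from on-chip residency and parallelism across many windows, which plain Python cannot show; the sketch only demonstrates that feature-dimension tiling with a complete score matrix gives the same result as ordinary attention.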
Abstract
To cope with the large number of pixels in high-resolution images, the Swin Transformer introduces window attention. This mechanism divides an image into non-overlapping windows and restricts attention computation to each window, significantly improving computational efficiency. To optimize this process further, one might consider replacing standard attention with flash attention, which has proven more efficient in language models. However, a direct substitution is ineffective: flash attention is designed for long sequences, whereas window attention deals with short sequences but must process many of them in parallel. In this report, we present Flash Window Attention, an optimized solution tailored specifically to window attention. Flash Window Attention improves attention computation efficiency by up to 300% and end-to-end runtime efficiency by up to 30%. Our code is available online.
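To make the window-attention workload concrete, the following sketch partitions a token grid into non-overlapping windows and counts query-key score entries for global versus window attention. The 56 x 56 token grid and 7 x 7 window are assumptions chosen to match a typical first Swin stage for a 224 x 224 input with 4 x 4 patches; `window_partition` and `score_entries` are hypothetical helpers, not part of the released code.

```python
def window_partition(h, w, m):
    """Group token positions of an h x w feature map into non-overlapping m x m windows."""
    if h % m or w % m:
        raise ValueError("h and w must be multiples of the window size m")
    windows = []
    for r0 in range(0, h, m):
        for c0 in range(0, w, m):
            windows.append([(r, c) for r in range(r0, r0 + m) for c in range(c0, c0 + m)])
    return windows


def score_entries(h, w, m=None):
    """Count query-key scores: (h*w)^2 for global attention, h*w*m^2 for window attention."""
    n = h * w
    return n * n if m is None else n * m * m


h = w = 56  # assumed: 224 x 224 image, 4 x 4 patches -> 56 x 56 tokens
m = 7       # assumed window size
windows = window_partition(h, w, m)
print(len(windows), len(windows[0]))                # 64 windows of 49 tokens each
print(score_entries(h, w), score_entries(h, w, m))  # 9834496 vs 153664
```

Here global attention over 3,136 tokens would need 64 times as many score entries as window attention. The price is that the work becomes 64 independent 49-token sequences per image and head, which is the many-short-sequences regime that flash attention, built for long sequences, does not target.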
