
Flashlight: PyTorch Compiler Extensions to Accelerate Attention Variants

Bozhi You, Irene Wang, Zelal Su Mustafaoglu, Abhinav Jangda, Angélica Moreira, Roshan Dathathri, Divya Mahajan, Keshav Pingali

TL;DR

Flashlight tackles the challenge of efficiently supporting diverse attention variants by integrating with the PyTorch 2.0 compiler (torch.compile) to automatically generate fused, FlashAttention-like kernels from idiomatic PyTorch code. It introduces a unified reduction IR, algebraic transformations of reductions, and tiling-aware dimension elimination. Together, these fuse complex attention operations into end-to-end kernels without static templates. Compared with FlexAttention and vanilla torch.compile, Flashlight achieves competitive or superior kernel performance. It also enables rapid exploration of new attention models, including differential attention and Evoformer-style self-attention, with measurable end-to-end latency improvements on AlphaFold Evoformer workloads. Built on TorchDynamo and TorchInductor within the PyTorch 2.0 stack, the approach delivers substantial inference speedups, along with training gains, across diverse attention variants on modern GPUs. This offers a practical path to high-performance support for a broad range of attention variants in PyTorch. Overall, Flashlight provides a general, compiler-native framework that unifies reduction-based fusion, tiling, and algebraic reasoning to push the performance envelope of arbitrary attention mechanisms while preserving PyTorch’s flexibility.
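To illustrate the kind of idiomatic program described above, the sketch below writes differential attention as plain array code. It computes two softmax attention maps and applies their difference, weighted by a scalar λ, to V. This is a minimal sketch under stated assumptions. It follows the formulation from the Differential Transformer paper, (softmax(Q1 K1^T / sqrt(d)) - λ softmax(Q2 K2^T / sqrt(d))) V, and omits that paper's per-head normalization. The Flashlight evaluation may use a different parameterization. It uses NumPy instead of PyTorch so it runs standalone, and the function names are hypothetical. Written this way and left unfused, the code materializes two n×n score matrices. Avoiding that cost is exactly what a fusing compiler targets.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def differential_attention(q1, k1, q2, k2, v, lam: float) -> np.ndarray:
    """Naive (unfused) differential attention for one head, no batch dim.

    Assumed formulation: (softmax(q1 k1^T / sqrt(d)) - lam * softmax(q2 k2^T / sqrt(d))) v
    Shapes: q1, k1, q2, k2: (n, d); v: (n, d_v).
    """
    d = q1.shape[-1]
    scale = 1.0 / np.sqrt(d)
    a1 = softmax(q1 @ k1.T * scale)   # first (n, n) attention map
    a2 = softmax(q2 @ k2.T * scale)   # second (n, n) attention map
    return (a1 - lam * a2) @ v        # difference of maps applied to values

# Usage: random single-head inputs.
rng = np.random.default_rng(0)
n, d = 8, 16
q1, k1, q2, k2, v = (rng.standard_normal((n, d)) for _ in range(5))
out = differential_attention(q1, k1, q2, k2, v, lam=0.5)
print(out.shape)  # (8, 16)
```

Setting `lam=0` recovers standard scaled dot-product attention. This is a quick sanity check that the variant is a strict generalization.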

Abstract

Attention is a fundamental building block of large language models (LLMs), so there have been many efforts to implement it efficiently. For example, FlashAttention leverages tiling and kernel fusion to optimize attention. Recently, a number of variants of attention have been introduced to enhance model quality or efficiency. Supporting them efficiently remains difficult because they usually require specialized kernels or hand-tuned implementations. FlexAttention recently addressed part of this gap by using static programming templates to support FlashAttention-like kernels for a subset of attention variants. In this paper, we introduce Flashlight, a compiler-native framework within the PyTorch ecosystem that automatically generates fused, FlashAttention-style kernels for arbitrary attention-based programs, without relying on static templates or predefined kernel specializations. Flashlight leverages PyTorch's compilation workflow to fuse and tile attention computations transparently, enabling efficient execution for diverse attention patterns. Not only does it support all variants expressible in the FlexAttention model, but it also handles more general, data-dependent attention formulations that are beyond the capabilities of FlexAttention. Our results show that Flashlight produces kernels with competitive or superior performance to FlexAttention, while offering the flexibility of native PyTorch code, enabling developers to rapidly explore new attention models without sacrificing performance.
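FlashAttention-style tiling relies on an algebraic rewrite of the softmax reduction. The row maximum and the softmax denominator can be maintained incrementally while streaming over key/value tiles, so the full score matrix is never materialized. The NumPy sketch below shows this standard online-softmax rewrite for a single head. It illustrates the general FlashAttention technique. It is not Flashlight's generated Triton code, and the function names and tile size are hypothetical.

```python
import numpy as np

def reference_attention(q, k, v):
    # Unfused baseline: materializes the full (n_q, n_k) score matrix.
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block: int = 4):
    """One pass over key/value tiles using the online-softmax rewrite."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((n_q, 1), -np.inf)       # running row max of scores
    l = np.zeros((n_q, 1))               # running softmax denominator
    acc = np.zeros((n_q, v.shape[-1]))   # running unnormalized output
    for start in range(0, k.shape[0], block):
        kb = k[start:start + block]
        vb = v[start:start + block]
        s = q @ kb.T * scale             # scores for this tile only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)   # rescales partial sums from earlier tiles
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ vb
        m = m_new
    return acc / l                       # normalize once at the end

# Usage: tile size 3 does not divide 10, so the last tile is partial.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v, block=3), reference_attention(q, k, v)))  # True
```

Each tile only rescales the running state by exp(m_old - m_new), so the result matches the unfused version up to floating-point rounding. In a fused GPU kernel, this loop becomes the inner loop over key blocks.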

Paper Structure

This paper contains 20 sections, 12 equations, 4 figures, 2 algorithms.

Figures (4)

  • Figure 1: Flashlight extends TorchInductor within the torch.compile stack, adding structural and semantic fusion passes with dimension demotion, algebraic transformation, and tiling-aware dimension elimination to generate optimized Triton kernels.
  • Figure 2: Runtimes of Flashlight and FlexAttention on H100 for attention variants supported by the FlexAttention template.
  • Figure 3: Runtimes of Flashlight and FlexAttention on A100 for attention variants supported by the FlexAttention template.
  • Figure 4: Runtimes of Flashlight and torch.compile on H100/A100 for attention variants not supported by FlexAttention.

Theorems & Definitions (1)

  • Definition 1