TiledAttention: a CUDA Tile SDPA Kernel for PyTorch

Taimur Khan

TL;DR

While production fused baselines remain stronger overall, TiledAttention delivers large speedups over standard eager attention paths and is available for direct use within PyTorch workflows, providing a practical balance between performance and customizability.

Abstract

TiledAttention is a scaled dot-product attention (SDPA) forward operator for attention-kernel research on NVIDIA GPUs. Implemented in cuTile Python (TileIR) and exposed as a PyTorch-callable function, it is easier to modify than low-level CUDA templates while retaining realistic behavior via online softmax and tiled $K,V$ streaming. The approach is both performant and directly editable at the schedule level from Python (tile shapes, staging, shared-memory layout), enabling rapid, reproducible kernel research without template-heavy CUDA/CUTLASS rewrites. We benchmark TiledAttention on an NVIDIA DGX GB10 node with a reproducible harness and compare against PyTorch SDPA (auto-dispatch) and explicit unfused baselines across sequence length, head dimension, and precision (FP16/BF16). While production fused baselines remain stronger overall, TiledAttention delivers large speedups over standard eager attention paths and is available for direct use within PyTorch workflows, providing a practical balance between performance and customizability.
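
To make the phrase "online softmax and tiled $K,V$ streaming" concrete, the following is a minimal pure-Python sketch of the standard online-softmax recurrence. It is not the TiledAttention kernel and uses no cuTile API. The function names `naive_attention` and `tiled_attention` and the `tile` parameter are hypothetical and chosen only for illustration. The sketch assumes non-causal attention for a single head, with matrices given as lists of rows.

```python
import math


def naive_attention(Q, K, V):
    """Reference SDPA: softmax(Q K^T / sqrt(D)) V, with full score rows in memory."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        denom = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / denom
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, tile=2):
    """Same result, but K and V are consumed one tile at a time (illustrative only)."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for q in Q:
        m = -math.inf      # running maximum of the scores seen so far
        l = 0.0            # running softmax denominator
        acc = [0.0] * dv   # running unnormalized output row
        for start in range(0, len(K), tile):
            k_tile = K[start:start + tile]
            v_tile = V[start:start + tile]
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(scores))
            alpha = math.exp(m - m_new)  # rescales contributions of earlier tiles
            p = [math.exp(s - m_new) for s in scores]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(pi * v[j] for pi, v in zip(p, v_tile))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once, after the last tile
    return out
```

Because each step only rescales the running sum and accumulator, the full score matrix is never materialized, and the result should match the naive reference up to floating-point rounding for any tile size.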

Paper Structure

This paper contains 30 sections, 4 equations, 6 figures, and 7 tables.

Figures (6)

  • Figure 1: Conceptual view: as $S$ grows, SDPA increasingly dominates end-to-end throughput.
  • Figure 2: TiledAttention forward pipeline at a glance.
  • Figure 3: Throughput versus sequence length for $D{=}128$ (FP16/BF16, non-causal).
  • Figure 4: Explicit baseline comparison (FP16): TiledAttention vs fused PyTorch SDPA, math-SDPA, and eager attention.
  • Figure 5: Relative performance regime map (TiledAttention as % of fused PyTorch SDPA) for FP16 non-causal runs.
  • ...and 1 more figure