FlashFormer: Whole-Model Kernels for Efficient Low-Batch Inference

Aniruddha Nrusimha, William Brandon, Mayank Mishra, Yikang Shen, Rameswar Panda, Jonathan Ragan-Kelley, Yoon Kim

TL;DR

FlashFormer tackles the bottlenecks of low-batch transformer inference by fusing the entire forward pass into a single, highly specialized kernel, which amortizes kernel launch costs. It employs metaprogramming (Cheetah), a unified memory pipeline, and cross-layer synchronization to overlap memory movement with computation. Across Llama 3.1 configurations and quantization settings, it achieves consistent speedups over established baselines, with notable gains for longer sequences and smaller models. The work demonstrates the practicality of whole-model kernel fusion for latency-sensitive deployments, while acknowledging that it is currently limited to single-GPU, low-batch regimes and outlining paths toward multi-GPU and model-general extensions.

Abstract

The size and compute characteristics of modern large language models have led to increased interest in developing specialized kernels tailored for particular training and inference workloads. Existing kernels primarily optimize for compute utilization, targeting large-batch training and inference settings. However, low-batch inference, where memory bandwidth and kernel launch overheads are significant factors, remains important for many applications of interest, such as edge deployment and latency-sensitive applications. This paper describes FlashFormer, which fuses the entire transformer forward pass into a single kernel to accelerate low-batch inference of large language models. Across various model sizes and quantization settings, FlashFormer achieves nontrivial speedups compared to existing inference kernels.

Paper Structure

This paper contains 35 sections, 8 figures, 10 tables.

Figures (8)

  • Figure 1: Depiction of the key benefit of FlashFormer. The top half depicts normal transformer execution with multiple kernels. After all chunks of Layer 1 have been loaded into the SM cache, we cannot start loading Layer 2 (top left). This, coupled with kernel overhead, leads to a break in memory loads and computation (top right). With FlashFormer, we fuse kernels to enable overlapping across prior kernel barriers (bottom left). This leads to more efficient overlapping and faster runtime (bottom right).
  • Figure 2: Memory bandwidth achieved by fused stacked linear layers. Cross-layer overlapping refers to starting to load weights for future layers before computation for the current layer has finished.
  • Figure 3: Left: Work partitioning between and within thread groups. The work for the matrix in global memory is split between thread groups. The work for each thread group is split into chunks for the pipeline buffer. Within each chunk, the work is further divided into sections. Right: Work partitioning within Thread Group A. Producer warps manage the pipeline buffer asynchronously, while consumer warps work synchronously on sections of chunks.
  • Figure 4: Conceptual diagram of a standard attention forward pass and our method. Left: a standard attention forward pass in an inference pipeline, broken down into four kernels. Right: FlashFormer, where all operations are part of one kernel invocation.
  • Figure 5: The attention forward pass per layer. Large matrices (green) are loaded asynchronously from global memory and do not wait on warps performing computation. The consumer warps work synchronously, with 3 global synchronizations per attention forward pass. Each of the repeated computations (in red) corresponds to an asynchronously loaded matrix split into chunks.
  • ...and 3 more figures