FlashFormer: Whole-Model Kernels for Efficient Low-Batch Inference
Aniruddha Nrusimha, William Brandon, Mayank Mishra, Yikang Shen, Rameswar Panda, Jonathan Ragan-Kelley, Yoon Kim
TL;DR
FlashFormer targets the bottlenecks of low-batch transformer inference by fusing the entire forward pass into a single, highly specialized kernel, which amortizes kernel launch costs across the whole model. It uses metaprogramming (Cheetah), a unified memory pipeline, and cross-layer synchronization to overlap memory movement with computation. Across Llama 3.1 configurations and quantization settings, it achieves consistent speedups over established baselines, with notable gains for longer sequences and smaller models. The work demonstrates that whole-model kernel fusion is practical for latency-sensitive deployments, while acknowledging that it is currently limited to single-GPU, low-batch regimes and outlining paths toward multi-GPU and model-general extensions.
Abstract
The size and compute characteristics of modern large language models have led to increased interest in developing specialized kernels tailored for particular training and inference workloads. Existing kernels primarily optimize for compute utilization, targeting large-batch training and inference settings. However, low-batch inference, where memory bandwidth and kernel launch overheads are significant factors, remains important for many settings of interest, such as edge deployment and latency-sensitive applications. This paper describes FlashFormer, which fuses the entire transformer forward pass into a single kernel to accelerate low-batch inference of large language models. Across various model sizes and quantization settings, FlashFormer achieves nontrivial speedups compared to existing inference kernels.
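
To make concrete why memory bandwidth and kernel launch overheads both matter at low batch, the following back-of-envelope sketch may help. It is not taken from the paper: the function name, the additive cost model (weight streaming time plus unhidden launch overhead), and every numeric value are illustrative assumptions, chosen only to show how fusing many launches into one changes a per-token latency estimate.

```python
# A rough per-token decode latency model. This is an illustrative assumption,
# not the paper's cost model: it treats latency as weight-streaming time plus
# kernel launch overhead, and assumes launch overhead is not hidden.

def decode_step_latency_s(weight_bytes, bandwidth_bytes_per_s,
                          num_kernel_launches, launch_overhead_s):
    """Estimate one decode step as memory time plus launch time (seconds)."""
    memory_time = weight_bytes / bandwidth_bytes_per_s
    launch_time = num_kernel_launches * launch_overhead_s
    return memory_time + launch_time

# All values below are hypothetical placeholders, not measurements.
WEIGHT_BYTES = 8e9 * 2        # assumed: 8B parameters stored in 16 bits
BANDWIDTH = 3e12              # assumed: 3 TB/s of device memory bandwidth
LAUNCH_OVERHEAD = 5e-6        # assumed: 5 microseconds per kernel launch
LAUNCHES_UNFUSED = 32 * 10    # assumed: 32 layers, 10 kernels per layer

unfused = decode_step_latency_s(WEIGHT_BYTES, BANDWIDTH,
                                LAUNCHES_UNFUSED, LAUNCH_OVERHEAD)
fused = decode_step_latency_s(WEIGHT_BYTES, BANDWIDTH,
                              1, LAUNCH_OVERHEAD)  # one whole-model kernel

print(f"unfused estimate: {unfused * 1e3:.3f} ms per token")
print(f"fused estimate:   {fused * 1e3:.3f} ms per token")
```

Under this simplified model, the memory term is the same in both cases and only the launch term shrinks, so the relative benefit grows as the weights get smaller (smaller models or more aggressive quantization). The model ignores any overlap between launch overhead and execution, and it ignores KV-cache traffic, so it should be read as intuition rather than a prediction.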
