FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores

Daniel Y. Fu, Hermann Kumbong, Eric Nguyen, Christopher Ré

TL;DR

This work tackles the inefficiency of FFT-based long-sequence convolutions on modern accelerators by introducing FlashFFTConv. FlashFFTConv re-expresses the FFT through an order-p Monarch decomposition mapped onto matrix multiplies, which makes end-to-end kernel fusion possible. It also adds domain-specific optimizations and two sparsity-inspired extensions (partial and frequency-sparse convolutions) that reduce I/O and memory. Together these achieve substantial speedups (up to 7.93× over PyTorch FFT convolutions) and memory savings. The approach enables longer contexts, demonstrated on the Path-512 high-resolution image task and on 4M-length DNA sequences. Under the same compute budget, it also matches the quality of models with twice the parameter count. The results suggest a practical path toward hardware-efficient convolutional sequence models that can rival Transformers in throughput and enable new applications in biology and high-resolution vision.
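To make "FFT as matrix multiplies" concrete, the sketch below shows an order-2 decomposition. It assumes this case corresponds to the classic four-step Cooley-Tukey factorization: a length-N = N1·N2 DFT becomes a small DFT matrix multiply, an elementwise twiddle, a second small DFT matrix multiply, and a transpose. It is a minimal NumPy illustration, not the FlashFFTConv CUDA kernel. The function names `dft_matrix` and `monarch_fft_order2` are hypothetical.

```python
import numpy as np


def dft_matrix(n: int) -> np.ndarray:
    """Dense n x n DFT matrix with entries exp(-2*pi*i*j*k / n)."""
    idx = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / n)


def monarch_fft_order2(x: np.ndarray, n1: int, n2: int) -> np.ndarray:
    """Length-(n1*n2) DFT along the last axis, computed with two small matmuls.

    Input index  t = n2*a + b    (a < n1, b < n2)
    Output index k = k1 + n1*k2  (k1 < n1, k2 < n2)
    """
    n = n1 * n2
    assert x.shape[-1] == n, "last axis must have length n1 * n2"
    batch = x.shape[:-1]
    xm = x.reshape(*batch, n1, n2)      # xm[..., a, b] = x[..., n2*a + b]
    y = dft_matrix(n1) @ xm             # length-n1 DFTs down each column
    twiddle = np.exp(-2j * np.pi * np.outer(np.arange(n1), np.arange(n2)) / n)
    y = y * twiddle                     # elementwise twiddle factors
    z = y @ dft_matrix(n2)              # length-n2 DFTs along each row
    # z[..., k1, k2] holds X[k1 + n1*k2]; transposing restores natural order
    return np.swapaxes(z, -1, -2).reshape(*batch, n)
```

Every step except the twiddle is a dense matrix multiply on small matrices. On GPUs, that is the kind of work that can be mapped to tensor cores. For example, `np.allclose(monarch_fft_order2(x, 32, 32), np.fft.fft(x))` should hold for any `x` of shape `(..., 1024)`.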

Abstract

Convolution models with long filters have demonstrated state-of-the-art reasoning abilities in many long-sequence tasks but lag behind the most optimized Transformers in wall-clock time. A major bottleneck is the Fast Fourier Transform (FFT)--which allows long convolutions to run in $O(N \log N)$ time in sequence length $N$ but has poor hardware utilization. In this paper, we study how to optimize the FFT convolution. We find two key bottlenecks: the FFT does not effectively use specialized matrix multiply units, and it incurs expensive I/O between layers of the memory hierarchy. In response, we propose FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT using matrix multiply units and enables kernel fusion for long sequences, reducing I/O. We also present two sparse convolution algorithms--1) partial convolutions and 2) frequency-sparse convolutions--which can be implemented simply by skipping blocks in the matrix decomposition, enabling further opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT convolutions by up to 7.93$\times$ over PyTorch and achieves up to 4.4$\times$ speedup end-to-end. Given the same compute budget, FlashFFTConv allows Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on Path-512, a high-resolution vision task where no model had previously achieved better than 50%. Furthermore, partial convolutions enable longer-sequence models--yielding the first DNA model that can process the longest human genes (2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models while maintaining or improving model quality.
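The sketch below shows the $O(N \log N)$ FFT convolution the abstract refers to, along with simple stand-ins for the two sparse variants. The paper implements the sparse variants by skipping blocks of the Monarch decomposition. This sketch instead expresses them as plain masks, assuming the intended semantics are these: a partial convolution truncates the filter in time, and a frequency-sparse convolution zeroes some of the filter's frequency coefficients. Keeping the lowest frequencies is a hypothetical sparsity pattern. The names `fft_conv`, `partial_conv`, and `freq_sparse_conv` are also hypothetical.

```python
import numpy as np


def fft_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{s<=t} k[s] * u[t-s], truncated to len(u).

    Zero-padding to 2N turns the FFT's circular convolution into a linear
    one, so the cost is O(N log N) instead of the O(N^2) direct sum.
    """
    n = u.shape[-1]
    size = 2 * n
    y = np.fft.irfft(np.fft.rfft(u, n=size) * np.fft.rfft(k, n=size), n=size)
    return y[..., :n]


def partial_conv(u: np.ndarray, k: np.ndarray, keep_taps: int) -> np.ndarray:
    """Hypothetical partial convolution: use only the first `keep_taps` filter taps."""
    return fft_conv(u, k[..., :keep_taps])


def freq_sparse_conv(u: np.ndarray, k: np.ndarray, keep_freqs: int) -> np.ndarray:
    """Hypothetical frequency-sparse convolution: zero all but the lowest
    `keep_freqs` rFFT coefficients of the filter before multiplying."""
    n = u.shape[-1]
    size = 2 * n
    k_f = np.fft.rfft(k, n=size)
    k_f[..., keep_freqs:] = 0           # drop the higher-frequency filter coefficients
    y = np.fft.irfft(np.fft.rfft(u, n=size) * k_f, n=size)
    return y[..., :n]
```

`fft_conv(u, k)` agrees with `np.convolve(u, k)[:len(u)]`. A shorter filter (partial) means less filter state to store. A sparser filter spectrum (frequency-sparse) means less work in the frequency-domain multiply.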

Paper Structure

This paper contains 65 sections, 12 equations, 5 figures, 19 tables, and 4 algorithms.

Figures (5)

  • Figure 1: Left: GPU memory hierarchy. Middle left: Order-$p$ Monarch decomposition of FFT convolution, with $p=2$. Middle right: Kernel fusion for end-to-end speedup. Right: FlashFFTConv introduces analogues of sparsity for convolutions.
  • Figure 2: Illustration of Monarch FFT decomposition.
  • Figure 3: Top: FlashFFTConv adapts the Monarch FFT decomposition to broadcast matrix multiply operations over the sequence instead of over the batch and hidden dimensions. Bottom: This converts HBM permutations into simple matrix transpose operations in SRAM.
  • Figure 4: Compute costs of different order-$p$ Monarch decompositions as sequence length increases on A100. Tradeoff points correspond to when the matrices in the Monarch decomposition reach the size of tensor cores on A100 and when the sequence becomes too long for SRAM.
  • Figure 5: t-SNE visualization of various genes and DNA segments using our new HyenaDNA-4M. The longest human gene, Dystrophin, is annotated.
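A back-of-envelope count helps explain why Figure 4 compares several orders $p$. The sketch assumes the sequence length factors into $p$ equal sizes $m = N^{1/p}$. Under that assumption, each of the $p$ stages multiplies $N/m$ vectors of length $m$ by an $m \times m$ DFT matrix, for roughly $p \cdot N \cdot N^{1/p}$ multiply-adds in total. The function name `monarch_matmul_flops` is hypothetical. The count ignores twiddle factors, permutations, tensor-core tile sizes, and SRAM capacity.

```python
def monarch_matmul_flops(n: int, p: int) -> int:
    """Rough multiply-add count for an order-p Monarch FFT of length n.

    Assumes n == m**p, so each of the p stages multiplies n // m vectors of
    length m by an m x m DFT matrix: n * m multiply-adds per stage.
    Twiddle factors and permutations are ignored.
    """
    m = round(n ** (1.0 / p))
    if m ** p != n:
        raise ValueError("this sketch assumes n is a perfect p-th power")
    return p * n * m
```

With $p=1$ this reduces to $N^2$, the dense DFT. The raw count keeps shrinking as $p$ grows (for $N=4096$: $p=2$ gives 524,288 and $p=3$ gives 196,608). Figure 4 nonetheless shows tradeoff points, which per its caption come from hardware limits: the factor matrices eventually reach the size of A100 tensor cores, and the sequence eventually no longer fits in SRAM.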