Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond

Costin-Andrei Oncescu, Sanket Purandare, Stratos Idreos, Sham Kakade

TL;DR

This work tackles the quadratic inference cost of long-context sequence generation by introducing Flash Inference, a tiling-based framework that enables exact inference in near-linear time for Long Convolution Sequence Models (LCSMs). These models, unlike transformers, are subquadratic at training time but still quadratic at inference. The core idea is to replace naive autoregressive updates with fast relaxed polynomial interpolation organized into tiles, which brings the cost down to $O(MDL\log^2 L)$ FLOPs (for $M$ layers, embedding dimension $D$, and sequence length $L$) and substantially reduces memory movement. The framework also enables parallelization across layers and yields large practical gains: on Hyena it achieves up to $7.8\times$ end-to-end speedups and $110\times$ speedups in the mixer. By abstracting the approach into a set of architectural properties and the $\mathcal{A}$/$\mathcal{T}$ machinery, the authors provide a general pathway to accelerate a broad class of causal, convolution-based sequence models, with potential extensions to data-dependent filters and other architectures.
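The $O(MDL\log^2 L)$ bound can be checked with a short counting argument. This is a sketch under two assumptions: the tiles follow the dyadic schedule of relaxed polynomial multiplication (at step $i$, one tile whose side is $\operatorname{lowbit}(i)$, the largest power of two dividing $i$), and each tile is computed with the FFT-based routine of Lemma 1 at cost at most $c(L_1+L_2)\log_2(L_1+L_2)$. The notation $\operatorname{lowbit}$, $N_k$, $K$ and the constant $c$ are introduced here for illustration only.

```latex
\begin{aligned}
N_k &:= \bigl|\{\, i \le L : \operatorname{lowbit}(i) = 2^k \,\}\bigr|
     = \Bigl\lfloor \tfrac{L}{2^k} \Bigr\rfloor - \Bigl\lfloor \tfrac{L}{2^{k+1}} \Bigr\rfloor
     \le \frac{L}{2^{k+1}} + 1, \\
\operatorname{cost}(2^k\text{-tile}) &\le c\,2^{k+1}\log_2\!\bigl(2^{k+1}\bigr) = 2c\,2^{k}(k+1)
     \quad\text{(Lemma 1 with } L_1 = L_2 = 2^k), \\
\sum_{k=0}^{K} N_k \cdot 2c\,2^{k}(k+1)
  &\le c\,L \sum_{k=0}^{K} (k+1) + 2c \sum_{k=0}^{K} 2^{k}(k+1),
  \qquad K = \lfloor \log_2 L \rfloor, \\
  &\le \tfrac{c}{2}\, L\,(K+1)(K+2) + 4c\,L\,(K+1) = O(L \log^2 L).
\end{aligned}
```

This is the cost for one channel of one mixer. Reading $M$ as the number of layers and $D$ as the embedding dimension (our reading of the notation), and assuming every channel of every layer runs the same schedule, the total is $O(MDL\log^2 L)$.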

Abstract

While transformers have been at the core of most recent advancements in sequence generative models, their computational cost remains quadratic in sequence length. Several subquadratic architectures have been proposed to address this computational issue. Some of them, including long convolution sequence models (LCSMs) such as Hyena, address this issue at training time but remain quadratic during inference. We propose a method for speeding up LCSMs' exact inference to quasilinear $O(L\log^2L)$ time, identify the key properties that make this possible, and propose a general framework that exploits them. Our approach, inspired by previous work on relaxed polynomial interpolation, is based on a tiling that reduces memory movement and shares computation. It has the added benefit of allowing almost complete parallelization across layers of the position-mixing part of the architecture. Empirically, we provide a proof-of-concept implementation for Hyena, which achieves up to a $7.8\times$ end-to-end speedup over standard inference, driven by a $110\times$ speedup in the position-mixing part.
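The tiling idea can be illustrated on a single channel of one mixer. The sketch below assumes the dyadic schedule used in relaxed polynomial multiplication: at 0-indexed step $t$, once $y_t$ is known, the last $U$ inputs (with $U$ the largest power of two dividing $t+1$) are pushed into the next $U$ outputs as one block. It is compared against the lazy baseline on a toy autoregressive loop in which the next input is $\tanh(z_t)$. That feedback rule, the function names, and the use of `np.convolve` for each tile are illustrative assumptions, not the authors' implementation; swapping `np.convolve` for an FFT-based convolution is what would give the quasilinear bound.

```python
import numpy as np

def lazy_inference(y0, rho):
    """Baseline: build each z[t] from scratch once y[0..t] is known (O(L^2) total)."""
    L = len(rho)
    y, z = np.zeros(L), np.zeros(L)
    y[0] = y0
    for t in range(L):
        z[t] = np.dot(y[: t + 1], rho[t::-1])      # sum_k y[k] * rho[t - k]
        if t + 1 < L:
            y[t + 1] = np.tanh(z[t])               # hypothetical stand-in for the rest of the model
    return z

def tiled_inference(y0, rho):
    """Same outputs, but past inputs are pushed forward into future outputs in dyadic tiles."""
    L = len(rho)
    y, z = np.zeros(L), np.zeros(L)
    y[0] = y0
    for t in range(L):
        z[t] += y[t] * rho[0]                      # diagonal term: z[t] is now complete
        U = (t + 1) & -(t + 1)                     # largest power of two dividing t + 1
        a, b = t - U + 1, t                        # source block y[a..b], already known
        c, d = t + 1, min(t + U, L - 1)            # target block z[c..d], still in the future
        if c <= d:
            lag_lo, lag_hi = c - b, d - a          # only these filter taps are touched
            conv = np.convolve(y[a : b + 1], rho[lag_lo : lag_hi + 1])
            off = b - a                            # index of the term that lands on z[c]
            z[c : d + 1] += conv[off : off + (d - c + 1)]
        if t + 1 < L:
            y[t + 1] = np.tanh(z[t])               # same hypothetical feedback as the baseline
    return z
```

Every pair (input $k$, output $j$) with $k < j$ falls in exactly one tile, so both functions return the same outputs. What changes is that the work is grouped into a few large blocks instead of $L$ separate dot products, which is what allows fast convolution and better memory behavior.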

Paper Structure

This paper contains 39 sections, 5 theorems, 12 equations, 5 figures, 4 tables, 4 algorithms.

Key Result

Lemma 1

Let $1 \leq l \leq r \leq l' \leq r' \leq L$ represent ranges of lengths $L_1=r-l+1$ and $L_2=r'-l'+1$ of $y$ and $z$, respectively. There exists an FFT-based algorithm running in $O(L_1+L_2)$ space and $O((L_1+L_2)\log(L_1+L_2))$ time complexity that, given access to $y_{[l, r]}$, computes all the contributions of $y_{[l, r]}$ to $z_{[l', r']}$.
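A minimal sketch of the kind of FFT-based routine Lemma 1 describes, assuming a single channel, 0-indexed inclusive ranges, and the causal convolution $z_j = \sum_i y_i \rho_{j-i}$, where the filter name $\rho$ and the function name `block_contributions` are our own notation. Only filter taps with lags in $[l' - r,\, r' - l]$ are needed, so the FFT length is $2L_1 + L_2 - 2$, consistent with the lemma's $O(L_1 + L_2)$ space and $O((L_1+L_2)\log(L_1+L_2))$ time.

```python
import numpy as np

def block_contributions(y, rho, l, r, lp, rp):
    """Contributions of y[l..r] to z[lp..rp] (0-indexed, inclusive),
    where z[j] = sum_i y[i] * rho[j - i]. Hypothetical helper."""
    if not (l <= r <= lp <= rp):
        raise ValueError("expected l <= r <= lp <= rp")
    lag_lo, lag_hi = lp - r, rp - l               # filter taps that can reach the target range
    a = y[l : r + 1]                              # length L1
    b = rho[lag_lo : lag_hi + 1]                  # length L1 + L2 - 1
    n = len(a) + len(b) - 1                       # = 2*L1 + L2 - 2, i.e. O(L1 + L2)
    conv = np.fft.irfft(np.fft.rfft(a, n) * np.fft.rfft(b, n), n)
    off = r - l                                   # conv index that lands on z[lp]
    return conv[off : off + (rp - lp + 1)]
```

Because the result only covers $z_{[l', r']}$, a caller would add it into a running output buffer; the tile schedule decides which $(l, r, l', r')$ blocks are requested and when.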

Figures (5)

  • Figure 1: Cell ($i$, $j$) corresponds to the contribution of mixer-input $y_i$ to mixer-output $z_j$. To compute $z_j$, all contributions along its line must be accounted for. Because inference is autoregressive, $y_i$ only becomes available after $z_{i-1}$ has been computed. (Left Top) shows the standard (lazy) approach, (Left Bottom) the eager approach, and (Right) our proposed method.
  • Figure 2: Real-world Hyena experiments: (a) the end-to-end inference time breakdown shows that Hybrid provides a 4.8$\times$ speedup over optimized baselines; (b) the cumulative mixer time of Hybrid scales 90$\times$ better; (c) Hybrid shows low variance in per-token response time except at token positions where large tiles are computed.
  • Figure 3: Mixer isolation in a synthetic setting: (a) different implementations of $\tau$ are optimal for different tile sizes, creating a Pareto-optimal curve from which Hybrid can choose; (b) the cumulative mixer inference time of Hybrid matches the best of all $\tau$ implementations; (c) end-to-end cumulative token inference breakdown.
  • Figure 4: Time breakdown for end-to-end Hyena experiments.
  • Figure 5: Time breakdown for end-to-end synthetic experiments.
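Figure 3(a) suggests picking the tile implementation $\tau$ according to tile size. The following is a hypothetical sketch of such a dispatcher, not the authors' Hybrid: it assumes just two candidate implementations (direct `np.convolve` and an FFT convolution), times each once per power-of-two tile size, and routes every tile to whichever was faster. The tile shape (a block of $U$ inputs meeting $2U-1$ filter taps) and all names here are illustrative assumptions.

```python
import time
import numpy as np

def tau_direct(y_seg, rho_seg):
    """Direct O(U^2) convolution; typically fastest for small tiles."""
    return np.convolve(y_seg, rho_seg)

def tau_fft(y_seg, rho_seg):
    """FFT-based O(U log U) convolution; typically fastest for large tiles."""
    n = len(y_seg) + len(rho_seg) - 1
    return np.fft.irfft(np.fft.rfft(y_seg, n) * np.fft.rfft(rho_seg, n), n)

CANDIDATES = {"direct": tau_direct, "fft": tau_fft}

def calibrate(max_log2, repeats=20, seed=0):
    """Time each candidate on random tiles of size U = 2^k and record the fastest name."""
    rng = np.random.default_rng(seed)
    best = {}
    for k in range(max_log2 + 1):
        U = 2 ** k
        y_seg, rho_seg = rng.standard_normal(U), rng.standard_normal(2 * U - 1)
        timings = {}
        for name, fn in CANDIDATES.items():
            start = time.perf_counter()
            for _ in range(repeats):
                fn(y_seg, rho_seg)
            timings[name] = time.perf_counter() - start
        best[U] = min(timings, key=timings.get)   # name of the faster implementation
    return best

def tau_hybrid(y_seg, rho_seg, best):
    """Dispatch a tile to the implementation that calibrated fastest for its size."""
    return CANDIDATES[best[len(y_seg)]](y_seg, rho_seg)
```

Calibrating once at start-up (for example `best = calibrate(12)`) and passing the table to `tau_hybrid` for every tile keeps the per-token decision to a dictionary lookup. Which implementation wins at each size depends on the hardware, so in this sketch the table is measured rather than hard-coded.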

Theorems & Definitions (8)

  • Definition
  • Lemma 1
  • Proposition 1
  • Proposition 2
  • Theorem 2
  • Proof
  • Lemma
  • Proof