
Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time

Zichang Liu, Jue Wang, Tri Dao, Tianyi Zhou, Binhang Yuan, Zhao Song, Anshumali Shrivastava, Ce Zhang, Yuandong Tian, Christopher Re, Beidi Chen

TL;DR

This paper tackles the high inference cost of large language models by introducing contextual sparsity, where input-dependent subsets of attention heads and MLP neurons are used per token. It develops DejaVu, a low-cost, asynchronous system that predicts sparsity on the fly and implements hardware-aware sparse computations to achieve substantial wall-clock speedups without retraining or sacrificing in-context learning. Key findings show that contextual sparsity is widespread (up to 85% for a given input), can be accurately predicted with lightweight models, and yields end-to-end latency reductions on OPT-175B of over 2× versus the state-of-the-art FasterTransformer and over 6× versus the widely used Hugging Face implementation, with negligible loss in quality. The work demonstrates that combining near-neighbor-style sparsity prediction with kernel-fused, memory-coalesced sparse matmuls can make LLMs more practical for latency-sensitive applications while remaining compatible with quantization techniques.

Abstract

Large language models (LLMs) with hundreds of billions of parameters have sparked a new wave of exciting AI applications. However, they are computationally expensive at inference time. Sparsity is a natural approach to reduce this cost, but existing methods either require costly retraining, have to forgo the LLM's in-context learning ability, or do not yield wall-clock time speedup on modern hardware. We hypothesize that contextual sparsity, i.e., small, input-dependent sets of attention heads and MLP parameters that yield approximately the same output as the dense model for a given input, can address these issues. We show that contextual sparsity exists, that it can be accurately predicted, and that we can exploit it to speed up LLM inference in wall-clock time without compromising the LLM's quality or in-context learning ability. Based on these insights, we propose DejaVu, a system that uses a low-cost algorithm to predict contextual sparsity on the fly given inputs to each layer, along with an asynchronous and hardware-aware implementation that speeds up LLM inference. We validate that DejaVu can reduce the inference latency of OPT-175B by over 2X compared to the state-of-the-art FasterTransformer, and over 6X compared to the widely used Hugging Face implementation, without compromising model quality. The code is available at https://github.com/FMInference/DejaVu.
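To make "input-dependent sets of MLP parameters that yield approximately the same output" concrete, here is a minimal sketch, assuming a bias-free ReLU MLP (biases omitted for brevity) with tiny random weights. An oracle keeps only the neurons whose ReLU is nonzero for this particular input; because inactive neurons contribute exactly zero, the sparse computation reproduces the dense output. This oracle computes every pre-activation, so it saves no work: DejaVu's point is to predict the active set cheaply instead. The helper names (`dense_mlp`, `sparse_mlp`, `oracle_active_set`) are hypothetical, not from the DejaVu codebase.

```python
import random
from typing import List

Vec = List[float]
Mat = List[List[float]]

def matvec(W: Mat, x: Vec) -> Vec:
    """Multiply each row of W by x (one row per MLP neuron)."""
    return [sum(w * a for w, a in zip(row, x)) for row in W]

def dense_mlp(x: Vec, W1: Mat, W2: Mat) -> Vec:
    """Dense ReLU MLP: out = sum_i relu(<W1[i], x>) * W2[i] over all neurons i.
    W2[i] holds neuron i's output weights (d_ff rows of length d_model)."""
    h = [max(0.0, z) for z in matvec(W1, x)]
    out = [0.0] * len(W2[0])
    for h_i, w2_i in zip(h, W2):
        for j, w in enumerate(w2_i):
            out[j] += h_i * w
    return out

def sparse_mlp(x: Vec, W1: Mat, W2: Mat, active: List[int]) -> Vec:
    """Same MLP, but only the neurons listed in `active` are computed."""
    out = [0.0] * len(W2[0])
    for i in active:
        h_i = max(0.0, sum(w * a for w, a in zip(W1[i], x)))
        for j, w in enumerate(W2[i]):
            out[j] += h_i * w
    return out

def oracle_active_set(x: Vec, W1: Mat) -> List[int]:
    """Oracle contextual sparsity: neurons whose ReLU is nonzero for this input."""
    return [i for i, z in enumerate(matvec(W1, x)) if z > 0.0]

random.seed(0)
d_model, d_ff = 8, 64
W1 = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(d_ff)]
W2 = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(d_ff)]
x = [random.gauss(0.0, 1.0) for _ in range(d_model)]

active = oracle_active_set(x, W1)
dense = dense_mlp(x, W1, W2)
sparse = sparse_mlp(x, W1, W2, active)
print(f"{len(active)}/{d_ff} neurons active")
print("max |dense - sparse| =", max(abs(a - b) for a, b in zip(dense, sparse)))
```

With symmetric random weights, roughly half the neurons are active; the paper reports far higher per-token sparsity in trained LLMs, which is what makes skipping the inactive rows and columns pay off.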

Paper Structure (51 sections, 19 theorems, 102 equations, 14 figures, 9 tables, 1 algorithm)

Key Result

Lemma 3.1

Let $0 < \epsilon_1 < \epsilon_2 < 1$ be the lower and upper bounds of the shrinking factor. Let $x$ be the input and $y$ be the output. We have the residual connection $y = x + F(x)$. For the MLP block $F(x)$, we have $\epsilon_1 \leq \| y - x \|_2 \leq \epsilon_2$. For the attention block $F(x)$, we have …
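The upper bound in Lemma 3.1 controls the residual update $F(x) = y - x$. A short derivation, under the added assumption $\|x\|_2 = 1$ (not stated in the excerpt above), shows how such a bound yields the high cosine similarity between consecutive layers reported in Figure 5. It uses Cauchy–Schwarz for the first line, the triangle inequality for the second, and $\epsilon_2 < 1$ to keep the numerator positive:

```latex
\begin{aligned}
\langle x, y\rangle &= \|x\|_2^2 + \langle x, F(x)\rangle \;\ge\; 1 - \|F(x)\|_2 \;\ge\; 1 - \epsilon_2, \\
\|y\|_2 &\le \|x\|_2 + \|F(x)\|_2 \;\le\; 1 + \epsilon_2, \\
\cos(x, y) &= \frac{\langle x, y\rangle}{\|x\|_2\,\|y\|_2} \;\ge\; \frac{1 - \epsilon_2}{1 + \epsilon_2}.
\end{aligned}
```

For an illustrative value $\epsilon_2 = 0.02$ (our choice, not a number from the paper), this gives $\cos(x, y) \ge 0.98 / 1.02 \approx 0.961$.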

Figures (14)

  • Figure 1: (1) LLMs have up to 85% contextual sparsity for a given input. (2) Contextual sparsity has much better efficiency-accuracy trade-offs (up to 7$\times$) than non-contextual sparsity or static sparsity.
  • Figure 2: DejaVu uses lookahead predictors to sidestep prediction costs: given the input to the attention layer at block $k$, they (asynchronously) predict the contextual sparsity for the MLP at block $k$, and given the input to the MLP at block $k$, they predict the sparsity for the attention heads at the next layer.
  • Figure 3: In Figure (a), we plot the percentage of non-activated attention heads. By keeping only heads that yield large output norms, we can silence over 80% of attention heads for a given token. In Figure (b), we plot the average sparsity we impose on MLP layers. We can zero out over 95% of MLP parameters for a given token.
  • Figure 4: We visualize the attention scores of three different heads for an example sentence. Heads 42 and 44 assign heavy attention scores to particular tokens, while Head 43 is more uniform.
  • Figure 5: Slowly Changing Embedding. Figure (a) shows the median cosine similarity between representations at two consecutive layers, across all layers, for different OPT models; all models show a similarity greater than 95%. Figure (b) shows that cosine similarity stays high even between layers a few apart. For the residual connection $X' = X + F(X)$ inside each block, Figures (c) and (d) plot the $\ell_2$ norms of $X$ and $F(X)$. $\|X\|$ is significantly larger than $\|F(X)\|$, which explains the slowly changing embedding.
  • ...and 9 more figures
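The lookahead schedule in Figure 2 can be sketched as a toy forward pass. This is a minimal illustration, not DejaVu's implementation: `Block`, `predict`, the per-head and per-neuron scoring predictors, and top-k selection are all hypothetical stand-ins. A Python thread pool is used only to show which computation overlaps with which; under CPython's GIL pure-Python math gains no real speedup, and the excerpt does not describe how the actual system schedules its GPU work.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Callable, List, Sequence

Vec = List[float]

def top_k(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest scores (ties broken by lower index)."""
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]

def predict(predictor: Callable[[Vec], Vec], v: Vec, k: int) -> List[int]:
    """Run a (hypothetical) sparsity predictor and keep its top-k components."""
    return top_k(predictor(v), k)

def add(a: Vec, b: Vec) -> Vec:
    return [u + w for u, w in zip(a, b)]

@dataclass
class Block:
    # Hypothetical stand-ins for one transformer block.
    attn: Callable[[Vec, List[int]], Vec]     # attention using only the given heads
    mlp: Callable[[Vec, List[int]], Vec]      # MLP using only the given neurons
    head_predictor: Callable[[Vec], Vec]      # scores this block's attention heads
    neuron_predictor: Callable[[Vec], Vec]    # scores this block's MLP neurons

def forward(x: Vec, blocks: List[Block], k_heads: int, k_neurons: int) -> Vec:
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Block 0 has no earlier input to look ahead from, so predict it up front.
        heads = predict(blocks[0].head_predictor, x, k_heads)
        for idx, blk in enumerate(blocks):
            # Lookahead 1: predict this block's MLP neurons from the attention
            # input x, in the background while attention runs.
            neurons_fut = pool.submit(predict, blk.neuron_predictor, x, k_neurons)
            h = add(x, blk.attn(x, heads))            # residual around attention
            neurons = neurons_fut.result()
            # Lookahead 2: predict the next block's heads from the MLP input h,
            # in the background while the MLP runs.
            heads_fut = None
            if idx + 1 < len(blocks):
                heads_fut = pool.submit(predict, blocks[idx + 1].head_predictor, h, k_heads)
            x = add(h, blk.mlp(h, neurons))           # residual around the MLP
            if heads_fut is not None:
                heads = heads_fut.result()
    return x

# Toy usage: every predictor call is logged so the schedule can be inspected.
calls = []

def make_block(b: int) -> Block:
    def head_pred(v: Vec) -> Vec:
        calls.append(("heads", b, tuple(v)))
        return [float(i) for i in range(4)]          # 4 heads
    def neuron_pred(v: Vec) -> Vec:
        calls.append(("neurons", b, tuple(v)))
        return [float(i) for i in range(6)]          # 6 neurons
    return Block(
        attn=lambda v, heads: [0.1 * len(heads)] * len(v),
        mlp=lambda v, neurons: [0.01 * len(neurons)] * len(v),
        head_predictor=head_pred,
        neuron_predictor=neuron_pred,
    )

out = forward([1.0, 2.0], [make_block(0), make_block(1)], k_heads=2, k_neurons=3)
print(out)
print(calls)
```

The log shows the design choice: each predictor consumes an input that is already available one sub-layer earlier, which is only sound because embeddings change slowly across layers (Figure 5).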

Theorems & Definitions (53)

  • Lemma 3.1: Informal
  • Definition 4.1: Approximate $\mathsf{MaxIP}$ in MLP
  • Remark 4.2
  • Lemma 4.3: Informal
  • Definition 7.1
  • Lemma 7.2
  • Remark 7.3
  • proof
  • Lemma 7.4
  • proof
  • ...and 43 more
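Definition 4.1 (Approximate MaxIP in MLP) is listed above but not reproduced in this excerpt. The connection it names can be sketched as follows: for a bias-free ReLU neuron, the activation is $\mathrm{relu}(\langle x, w_i\rangle)$, so finding the most active neurons is a maximum-inner-product search over the neuron weight vectors. The `(c, tau)` guarantee checked below is an assumption taken from the general approximate-MaxIP literature, not necessarily the paper's exact definition, and `exact_maxip` / `satisfies_approx_maxip` are hypothetical helpers. Brute force costs one inner product per neuron, which is the work a cheap predictor aims to avoid.

```python
from typing import List, Sequence

def inner(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(u * v for u, v in zip(a, b))

def exact_maxip(x: Sequence[float], W: List[Sequence[float]]) -> int:
    """Brute-force MaxIP: index of the row of W with the largest inner product with x."""
    return max(range(len(W)), key=lambda i: inner(x, W[i]))

def satisfies_approx_maxip(x: Sequence[float], W: List[Sequence[float]],
                           j: int, c: float, tau: float) -> bool:
    """Check an ASSUMED (c, tau)-approximate MaxIP guarantee for a returned index j:
    if some row reaches inner product >= tau, row j must reach at least c * tau."""
    if max(inner(x, w) for w in W) < tau:
        return True  # no row reaches tau, so the guarantee promises nothing
    return inner(x, W[j]) >= c * tau

x = [1.0, 0.0]
W = [[0.9, 0.1], [0.5, 0.5], [-1.0, 0.0]]   # rows = neuron weight vectors
best = exact_maxip(x, W)
print(best,
      satisfies_approx_maxip(x, W, 1, c=0.5, tau=0.8),
      satisfies_approx_maxip(x, W, 2, c=0.5, tau=0.8))
# prints: 0 True False
```

Row 1 is an acceptable approximate answer (0.5 ≥ 0.5 · 0.8), while row 2, a neuron the ReLU would zero out for this input, is not.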