
Spark Transformer: Reactivating Sparsity in FFN and Attention

Chong You, Kan Wu, Zhipeng Jia, Lin Chen, Srinadh Bhojanapalli, Jiaxian Guo, Utku Evci, Jan Wassenberg, Praneeth Netrapalli, Jeremiah J. Willcock, Suvinay Subramanian, Felix Chern, Alek Andreev, Shreya Pathak, Felix Yu, Prateek Jain, David E. Culler, Henry M. Levy, Sanjiv Kumar

TL;DR

Spark Transformer reactivates activation sparsity in modern Transformers by enforcing sparse activations in both FFN and attention with a linear-time Statistical-Top_k operator and a low-cost predictor derived from a subset of Q/K, achieving substantial FLOPs reductions while remaining nearly quality neutral. The approach leaves only about $8\%$ of FFN neurons active and lets each token attend to at most $256$ tokens, translating to a roughly $2.5\times$ reduction in FLOPs and decoding wall-time speedups of up to $1.79\times$ on CPU and $1.40\times$ on GPU, demonstrated on Gemma-2 with single-stage training. Theoretical and practical contributions include threshold-estimation guarantees for Statistical-Top_k, differentiable sparsification via soft-thresholding, and hardware-aware sparse matrix multiplications that preserve training simplicity. The work also discusses potential synergies with speculative decoding and quantization, and frames Spark Transformer as a pathway toward more efficient, sparsity-enabled inference for large-scale language models.

Abstract

The discovery of the lazy neuron phenomenon in trained Transformers, where the vast majority of neurons in their feed-forward networks (FFN) are inactive for each token, has spurred tremendous interest in activation sparsity for enhancing large model efficiency. While notable progress has been made in translating such sparsity to wall-time benefits, modern Transformers have moved away from the ReLU activation function crucial to this phenomenon. Existing efforts on re-introducing activation sparsity often degrade model quality, increase parameter count, or complicate or slow down training. Sparse attention, the application of sparse activation to the attention mechanism, often faces similar challenges. This paper introduces the Spark Transformer, a novel architecture that achieves a high level of activation sparsity in both FFN and the attention mechanism while maintaining model quality, parameter count, and standard training procedures. Our method realizes sparsity via top-$k$ masking for explicit control over sparsity level. Crucially, we introduce statistical top-$k$, a hardware-accelerator-friendly, linear-time approximate algorithm that avoids costly sorting and mitigates significant training slowdown from standard top-$k$ operators. Furthermore, Spark Transformer reallocates existing FFN parameters and attention key embeddings to form a low-cost predictor for identifying activated entries. This design not only mitigates quality loss from enforced sparsity, but also enhances wall-time benefit. Pretrained with the Gemma-2 recipe, Spark Transformer demonstrates competitive performance on standard benchmarks while exhibiting significant sparsity: only 8% of FFN neurons are activated, and each token attends to a maximum of 256 tokens. This sparsity translates to a 2.5x reduction in FLOPs, leading to decoding wall-time speedups of up to 1.79x on CPU and 1.40x on GPU.
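
The statistical top-$k$ idea described in the abstract can be sketched as follows. This is a minimal illustration, not the authors' implementation: it assumes the activations are approximately Gaussian (the setting of Theorem 3.1), estimates the threshold as the sample mean plus the sample standard deviation times the standard normal quantile at $1 - k/d$, and then applies soft-thresholding. The function names and the exact threshold formula are assumptions made for this sketch.

```python
import random
from statistics import NormalDist, fmean, stdev


def statistical_threshold(x, k):
    """Estimate a threshold that roughly k entries of x exceed.

    Assumption for this sketch: entries of x are approximately Gaussian,
    so the (1 - k/d) quantile can be estimated from the mean and standard
    deviation without sorting.
    """
    d = len(x)
    if not 1 <= k <= d - 1:
        raise ValueError("k must satisfy 1 <= k <= d - 1")
    mu = fmean(x)        # one linear pass over x
    sigma = stdev(x)     # sample standard deviation, also linear time
    return mu + sigma * NormalDist().inv_cdf(1.0 - k / d)


def statistical_top_k(x, k):
    """Soft-threshold x so that about k entries remain nonzero."""
    theta = statistical_threshold(x, k)
    return [max(v - theta, 0.0) for v in x]


# Usage: keep about 8% of 10,000 synthetic Gaussian activations.
rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(10_000)]
y = statistical_top_k(x, k=800)
nnz = sum(1 for v in y if v > 0.0)
print(f"nonzero entries: {nnz} (target 800)")
```

Because the threshold is estimated rather than computed exactly, the number of surviving entries is only approximately $k$. Soft-thresholding (subtracting the threshold rather than masking) keeps the output continuous in the input, which is what makes the operator usable during training.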

Paper Structure

This paper contains 36 sections, 2 theorems, 42 equations, 12 figures, 2 tables.

Key Result

Theorem 3.1

Let ${\bm{x}}\in\mathbb{R}^{d}$ be a vector with entries drawn i.i.d. from $\mathcal{N}(\mu,\sigma^{2})$. For any $1\le k\le d-1$, let $\theta({\bm{x}},k)$ be a scalar defined in eq:statistical-topk. Take any $\delta \in (0, 1)$ and assume $d\ge\max\{2,\log\frac{6}{\delta}\}$. With a probability of

Figures (12)

  • Figure 1: Spark Transformer improves inference efficiency via high-level activation sparsity in both FFN and attention, while being nearly quality neutral. (a) Comparison to prior work in terms of relative FLOPs per token at 8k sequence length (y-axis) vs relative training loss (x-axis). [$\blacksquare$] We use standard Gemma-2 (team2024gemma) as baseline, which has no activation sparsity. [•] Methods employing activation sparsity within the FFN layers only. Our Spark FFN achieves the most favorable trade-off compared to ReLU, ReLU$^2$, and Topk, which refer to standard Gemma-2 with the activation function switched to ReLU (mirzadeh2023relu), ReLU$^2$ (zhang2024relu), and the composition of Topk and GELU, respectively. [$\blacktriangle$] Combining Spark FFN (with 8% activated parameters) with Spark Attention (with at most $256$ attended tokens), our Spark Transformer achieves performance comparable to Gemma-2 while reducing FLOPs to 40%. (b) Evaluation on standard downstream tasks confirms near-quality neutrality of Spark Transformer. (c) Prefill / decode wall times demonstrate a 1.86$\times$/1.64$\times$ speedup resulting from FLOPs reduction. Results are obtained on a 4-core CPU for prompts of 4096 tokens. For prefill, the prompt is chunked into batches of 64 tokens, following a default setup of gemma.cpp (gemmacpp).
  • Figure 2: Architecture of Spark FFN and Spark Attention. (Left) Unified illustration of standard FFN (i.e., \ref{eq:ffn}) and standard Attention (i.e., \ref{eq:attention}). In the case of FFN, ${\bm{q}} \in \mathbb{R}^{d_\text{model}}$ is the input, ${\bm{K}}$ and ${\bm{V}}$ are the first and second layer weights, respectively, and $\sigma()$ is GELU. In the case of Attention, ${\bm{q}} \in \mathbb{R}^{d_\text{attn}}$ is the query, ${\bm{K}}$ and ${\bm{V}}$ are key and value matrices, respectively, and $\sigma()$ is softmax. (Right) Unified illustration of Spark FFN (i.e., \ref{eq:lazy-ffn}) and Spark Attention (i.e., \ref{eq:lazy-attention}). In the case of Spark FFN, $\sigma_1()$ is GELU and $\sigma_2()$ is identity. In the case of Spark Attention, $\sigma_1()$ is softmax and $\sigma_2()$ is softplus. In both cases, $\operatorname{{Statistical-Top}}_k$ (i.e., \ref{eq:statistical-topk}) is applied to introduce sparsity, which enables sparse matrix multiplication with ${\bm{K}}_2$ and ${\bm{V}}$ that reduces FLOPs count.
  • Figure 3: Sparsity in the intermediate activation of Spark FFN and Spark Attention across $26$ layers at selected training steps. For FFN we report the percentage of nonzero entries out of $d_\text{ff}=13824$ entries. For Attention, we report the number of nonzero entries (i.e., attended tokens). Our hyper-parameter choice is to have 8% nonzeros in Spark FFN and at most 256 nonzeros in Spark Attention.
  • Figure 4: Illustration of the matrix multiplication implementation using sparse activation. (a) Vector-Masked Matrix Multiplication takes a dense vector ${\bm{q}}[r\!:]$, a dense matrix ${\bm{K}}_2^\top$, and a mask from statistical top-$k$ on ${\bm{K}}_1^\top {\bm{q}}[:\!r]$ to compute ${\bm{u}} := ({\bm{K}}_2^\top {\bm{q}}[r\!:]) \,\odot\,$mask. It skips memory loading and compute associated with the masked columns. (b) Sparse Vector-Matrix Multiplication takes a sparse activation vector ${\bm{u}}$ to compute weighted sum of rows in the dense matrix ${\bm{V}}$. It skips loading and computation of rows corresponding to 0's in ${\bm{u}}$. To optimize performance, we implement Sparse Vector-Matrix Multiplication using tiling, which helps minimize cross-CPU core synchronization.
  • Figure 5: Spark Transformer decoding speedup from activation sparsity on various hardware platforms. We report decoding speed of Spark Transformer without hardware optimization for sparse activation, with hardware optimization for sparsity in Spark FFN only, and with hardware optimization for sparsity in both Spark FFN and Spark Attention. All experiments use a decode batch size of 1.
  • ...and 7 more figures
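
The two sparse kernels described in the Figure 4 caption can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the paper's optimized implementation: matrices are lists of rows, the mask is given as a list of booleans rather than produced by statistical top-$k$, tiling and multi-core scheduling are omitted, and the function names are hypothetical.

```python
def vector_masked_matmul(k2_t, q_tail, mask):
    """Compute u = (K2^T q[r:]) * mask, skipping masked-out rows of K2^T.

    k2_t:   K2^T stored as a list of rows, one row per intermediate entry
    q_tail: the dense vector q[r:]
    mask:   booleans marking which entries of u to compute
    """
    u = [0.0] * len(k2_t)
    for i, keep in enumerate(mask):
        if keep:  # masked entries are never loaded or multiplied
            u[i] = sum(w * x for w, x in zip(k2_t[i], q_tail))
    return u


def sparse_vector_matmul(u, v):
    """Compute the weighted sum of rows of V, skipping rows where u is zero."""
    out = [0.0] * len(v[0])
    for i, u_i in enumerate(u):
        if u_i != 0.0:  # rows of V paired with zeros in u are skipped
            for j, v_ij in enumerate(v[i]):
                out[j] += u_i * v_ij
    return out


# Usage on a tiny example with three intermediate entries, one masked out.
k2_t = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
q_tail = [2.0, 3.0]
mask = [True, False, True]
v = [[1.0, 0.0], [9.0, 9.0], [0.0, 2.0]]

u = vector_masked_matmul(k2_t, q_tail, mask)
out = sparse_vector_matmul(u, v)
print(u, out)
```

In both functions the work is proportional to the number of active entries rather than to the full intermediate dimension, which is where the FLOPs and memory-traffic savings come from.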

Theorems & Definitions (4)

  • Theorem 3.1
  • Theorem 3.2
  • proof
  • proof