
SpAtten: Efficient Sparse Attention Architecture with Cascade Token and Head Pruning

Hanrui Wang, Zhekai Zhang, Song Han

TL;DR

SpAtten tackles the attention bottleneck by co-designing algorithms and hardware to exploit token sparsity, head sparsity, and variable precision. It introduces cascade token pruning, cascade head pruning, and progressive quantization, plus a high-throughput top-k engine that enables on-the-fly pruning without retraining. Across BERT and GPT-2 workloads, the approach reduces DRAM access by 10.0x on average with no accuracy loss and achieves large speedups over accelerator, GPU, and CPU baselines. The work shows how targeted pruning and adaptive quantization, implemented jointly in algorithm and hardware, can accelerate memory-bound NLP inference.
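The top-k engine is a hardware unit; purely as a software analogue, the selection it performs can be sketched with randomized quickselect: find the k-th largest importance score, then keep every entry above that threshold, breaking ties by position so exactly k entries survive in their original order. The names `kth_largest` and `top_k_indices` are hypothetical, and this sketch is not claimed to reproduce the paper's hardware design.

```python
import random
from typing import List


def kth_largest(scores: List[float], k: int) -> float:
    """Return the k-th largest value (1-indexed) using randomized quickselect."""
    if not 1 <= k <= len(scores):
        raise ValueError("k must be in [1, len(scores)]")
    items = list(scores)
    while True:
        pivot = random.choice(items)
        greater = [x for x in items if x > pivot]
        equal = [x for x in items if x == pivot]
        if k <= len(greater):
            # The answer lies among the values strictly above the pivot.
            items = greater
        elif k <= len(greater) + len(equal):
            # The pivot itself is the k-th largest value.
            return pivot
        else:
            # Discard everything >= pivot and search the smaller values.
            k -= len(greater) + len(equal)
            items = [x for x in items if x < pivot]


def top_k_indices(scores: List[float], k: int) -> List[int]:
    """Indices of the k highest scores, returned in original (sequence) order."""
    threshold = kth_largest(scores, k)
    above = [i for i, s in enumerate(scores) if s > threshold]
    ties = [i for i, s in enumerate(scores) if s == threshold]
    # Fill the remaining slots with the earliest tied entries.
    keep = set(above + ties[: k - len(above)])
    return [i for i in range(len(scores)) if i in keep]
```

For example, `top_k_indices([0.1, 0.9, 0.3, 0.7, 0.5], 3)` keeps positions 1, 3, and 4. Finding a threshold and then filtering avoids a full sort, which is the kind of linear-time behavior that makes on-the-fly ranking cheap.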

Abstract

The attention mechanism is becoming increasingly popular in Natural Language Processing (NLP) applications, outperforming convolutional and recurrent architectures. However, attention becomes the computation bottleneck because of its quadratic computational complexity in input length, complicated data movement, and low arithmetic intensity. Moreover, existing NN accelerators mainly focus on optimizing convolutional or recurrent models and cannot efficiently support attention. In this paper, we present SpAtten, an efficient algorithm-architecture co-design that leverages token sparsity, head sparsity, and quantization opportunities to reduce attention computation and memory access. Inspired by the high redundancy of human languages, we propose the novel cascade token pruning to prune away unimportant tokens in the sentence. We also propose cascade head pruning to remove nonessential heads. Cascade pruning is fundamentally different from weight pruning since there is no trainable weight in the attention mechanism, and the pruned tokens and heads are selected on the fly. To efficiently support them on hardware, we design a novel top-k engine to rank token and head importance scores with high throughput. Furthermore, we propose progressive quantization that first fetches MSBs only and performs the computation; if the confidence is low, it fetches LSBs and recomputes the attention outputs, trading computation for memory reduction. Extensive experiments on 30 benchmarks show that, on average, SpAtten reduces DRAM access by 10.0x with no accuracy loss, and achieves 1.6x, 3.0x, 162x, 347x speedup, and 1.4x, 3.2x, 1193x, 4059x energy savings over A3 accelerator, MNNFast accelerator, TITAN Xp GPU, Xeon CPU, respectively.
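Progressive quantization can be illustrated with a small sketch. The following choices are assumptions, not taken from the paper: keys are stored as 8-bit unsigned codes split into 4 MSBs and 4 LSBs, a single per-tensor `scale` dequantizes them, and the maximum softmax probability serves as the confidence signal compared against a hypothetical `threshold`. All function names are hypothetical, and the "fetch" of LSBs is simulated by a flag rather than an actual memory access.

```python
import math
from typing import List, Tuple

# Hypothetical bit split: an 8-bit unsigned code = 4 MSBs followed by 4 LSBs.
MSB_BITS = 4
LSB_BITS = 4


def split(code: int) -> Tuple[int, int]:
    """Split an unsigned integer code into its MSB and LSB parts."""
    return code >> LSB_BITS, code & ((1 << LSB_BITS) - 1)


def dequantize(code: int, scale: float, use_lsb: bool) -> float:
    """Dequantize a code; with MSBs only, the missing low bits count as zero."""
    msb, lsb = split(code)
    return ((msb << LSB_BITS) | (lsb if use_lsb else 0)) * scale


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def progressive_attention(query: List[float], key_codes: List[List[int]],
                          scale: float, threshold: float) -> Tuple[List[float], bool]:
    """Return (attention probabilities, whether LSBs had to be fetched)."""
    def probs(use_lsb: bool) -> List[float]:
        scores = [sum(q * dequantize(c, scale, use_lsb) for q, c in zip(query, key))
                  / math.sqrt(len(query))
                  for key in key_codes]
        return softmax(scores)

    p = probs(use_lsb=False)          # first pass: MSBs only
    if max(p) >= threshold:           # confident: accept the low-precision result
        return p, False
    return probs(use_lsb=True), True  # low confidence: add LSBs and recompute
```

When one key clearly dominates, the MSB-only pass is already confident and the LSBs are never touched; when the MSB-only scores are nearly tied (flat distribution), the sketch falls back to full precision, mirroring the trade of extra computation for reduced memory traffic.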

Paper Structure

This paper contains 26 sections, 2 equations, 23 figures, 4 tables, 3 algorithms.

Figures (23)

  • Figure 1: Cascade token and head pruning removes redundant tokens and heads globally across layers. Evaluated with BERT-Base on SST-2 dataset.
  • Figure 2: End-to-end GPT-2 latency breakdown on various platforms, and attention latency breakdown on TITAN Xp GPU. Attention accounts for over 50% of total latency. Data movement accounts for 73% of attention latency.
  • Figure 3: NLP model architecture with attention. BERT only contains the summarization stage. GPT-2 contains summarization and generation stages.
  • Figure 4: Cascade token pruning removes redundant tokens and their entire Q, K, V vectors according to the cumulative token importance scores computed from $attention\_prob$. Cascade head pruning removes unimportant heads and the corresponding chunks in all Q, K, V vectors according to the cumulative head importance scores computed from $attention\_out$. Once a token or head is pruned, it never appears in any following layer, hence the name cascade pruning. More tokens and heads are pruned away in deeper layers.
  • Figure 5: Attention probabilities for BERT are summed over each column to get importance scores. Tokens with small importance scores are pruned.
  • ...and 18 more figures
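The cumulative scoring described for Figures 4 and 5 can be sketched as follows, as an illustration under stated assumptions rather than the paper's implementation. Assumptions: a single sequence; `attention_prob` is given per alive head as a list of rows over the currently alive tokens; head importance is taken as the sum of absolute values of each head's `attention_out` chunk (the captions only say it is computed from `attention_out`); and the number of tokens or heads to keep per layer (`keep`) is supplied by the caller. A plain sort stands in for the top-k engine, and all function names are hypothetical.

```python
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def accumulate_token_scores(cumulative: Dict[int, float], alive: List[int],
                            attention_prob: Sequence[Matrix]) -> None:
    """Add column sums of each head's attention_prob to the cumulative token scores.

    attention_prob[h][i][j] is how much query i attends to key j, where j
    indexes the currently alive tokens, so alive[j] is the original token id.
    """
    for head in attention_prob:
        for row in head:
            for j, p in enumerate(row):
                cumulative[alive[j]] = cumulative.get(alive[j], 0.0) + p


def accumulate_head_scores(cumulative: Dict[int, float], alive_heads: List[int],
                           attention_out: Sequence[Matrix]) -> None:
    """Add a magnitude-based importance (assumed: sum of |x|) for each alive head."""
    for h, out in zip(alive_heads, attention_out):
        cumulative[h] = cumulative.get(h, 0.0) + sum(abs(v) for row in out for v in row)


def cascade_prune(alive: List[int], cumulative: Dict[int, float], keep: int) -> List[int]:
    """Keep the `keep` alive ids with the highest cumulative scores.

    Only ids that are still alive are ranked, so anything pruned in an
    earlier layer can never come back (the cascade property).
    """
    ranked = sorted(alive, key=lambda t: cumulative[t], reverse=True)[:keep]
    return sorted(ranked)
```

Because scores accumulate across layers and pruning only ever shrinks the alive set, a token that ranks low early on stays removed even if it would have scored well later; this trades some flexibility for never having to reload its Q, K, V vectors.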