A2SF: Accumulative Attention Scoring with Forgetting Factor for Token Pruning in Transformer Decoder
Hyun-rae Jo, Dongkun Shin
TL;DR
The paper tackles the memory bottleneck that the KV cache of a Transformer decoder creates during long-sequence generation by introducing Accumulative Attention Score with Forgetting Factor (A2SF). A2SF incorporates a Forgetting Factor $\alpha$ into the accumulation of attention scores, $A^h_{n,k} = \sum_{q=1}^{n} \alpha^{n-q} \cdot S^h_{q,k}$ with $0<\alpha<1$, where $S^h_{q,k}$ is the attention score that query $q$ assigns to key $k$ in head $h$. Contributions from older queries are thereby downweighted, which allows tokens to be compared fairly under causal masking. Empirical results on OPT and LLaMA variants show that A2SF improves accuracy over prior KV cache pruning methods, with gains of up to 7.8% in the 1-shot setting and 5.1% in the 0-shot setting on LLaMA 2 7B, and that it produces pruning masks closely aligned with those of ideal pruning. The work demonstrates that a tunable forgetting mechanism can make token pruning more effective without retraining, and suggests that A2SF can be integrated into existing KV cache processing pipelines.
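A minimal sketch of the scoring formula for a single head, in plain Python for illustration only (it is not the authors' implementation). It assumes `S` is an already causally masked, row-normalized attention matrix whose row `q` holds the scores of query `q` over all keys, and it uses 0-based indices, so the exponent $n-q$ becomes `n - 1 - q`. The function names `a2sf_closed_form` and `a2sf_recurrent` and the toy matrix are hypothetical. The sketch also checks that the closed form equals the recurrence $A_n = \alpha A_{n-1} + S_n$, in which the running score is multiplied by the Forgetting Factor once per decoding step.

```python
from typing import List

Matrix = List[List[float]]


def a2sf_closed_form(S: Matrix, alpha: float) -> List[float]:
    """A_{n,k} = sum_{q=1}^{n} alpha^(n-q) * S_{q,k} for a single head.

    With 0-based rows, query q is weighted by alpha ** (n - 1 - q), so the
    newest query has weight 1 and older queries are progressively discounted.
    """
    n = len(S)
    return [
        sum(alpha ** (n - 1 - q) * S[q][k] for q in range(n))
        for k in range(len(S[0]))
    ]


def a2sf_recurrent(S: Matrix, alpha: float) -> List[float]:
    """Same scores, updated once per decoding step: A_n = alpha * A_{n-1} + S_n."""
    acc = [0.0] * len(S[0])
    for row in S:
        acc = [alpha * a + s for a, s in zip(acc, row)]
    return acc


# Toy causally masked, row-normalized attention (4 queries x 4 keys).
S: Matrix = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.3, 0.5, 0.0],
    [0.1, 0.2, 0.3, 0.4],
]

print(a2sf_closed_form(S, 1.0))  # alpha = 1: plain accumulation
print(a2sf_closed_form(S, 0.5))
print(a2sf_recurrent(S, 0.5))    # matches the closed form up to rounding
```

On this toy matrix with $\alpha = 1$ (plain accumulation), the first key collects four terms and scores about 1.8, while the newest key scores 0.4. With $\alpha = 0.5$, all four scores fall between 0.4 and 0.55, so a token's age no longer dominates its score. The recurrent form needs only one running value per cached key and head, which makes it a natural fit for step-by-step decoding.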
Abstract
Recently, Transformer-based large language models (LLMs) have faced memory bottlenecks caused by the KV cache, especially when handling long sequences. Previous studies have proposed KV cache compression techniques that identify insignificant tokens based on Accumulative Attention Scores and remove their entries from the KV cache, observing that only a few tokens play an important role in attention operations. However, we have observed that the existing Accumulative Attention Score is not suitable for the Transformer decoder structure. In a decoder model, causal masking means that an earlier token is attended to by more subsequent queries, so the number of times its Attention Score is accumulated depends on the order in which it appears, causing an uneven comparison between tokens. To solve this, we propose the Accumulative Attention Score with Forgetting Factor (A2SF) technique, which introduces a Forgetting Factor into the Attention Score accumulation process. A2SF penalizes past Attention Scores generated from old tokens by repeatedly multiplying them by the Forgetting Factor over time. As a result, older tokens receive a larger penalty, which provides fairness among tokens of different ages. This fair comparison among tokens allows important tokens to be selected more effectively. We have verified the accuracy improvement of A2SF on the OPT and LLaMA models: A2SF improves the accuracy of LLaMA 2 by up to 7.8% and 5.1% in the 1-shot and 0-shot settings, respectively.
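The unfair comparison caused by causal masking, and how a Forgetting Factor changes which tokens survive pruning, can be illustrated with a small sketch. It is plain Python under simplifying assumptions: a single head, a toy attention matrix, and a hypothetical `keep_top_k` rule that keeps the `budget` highest-scoring keys. Practical KV cache eviction policies may combine scores with other heuristics, such as reserving slots for recent tokens, and this sketch omits them. The names `accumulation_counts`, `a2sf_scores`, and `keep_top_k` are illustrative, not taken from the paper.

```python
from typing import List

Matrix = List[List[float]]

# Toy causally masked attention for one head: row q holds query q's
# scores over keys 0..3 (zeros above the diagonal, rows sum to 1).
S: Matrix = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.3, 0.5, 0.0],
    [0.1, 0.2, 0.3, 0.4],
]


def accumulation_counts(n: int) -> List[int]:
    """Number of unmasked scores each key receives after n decoding steps.

    Key k is visible only to queries k..n-1, so it collects n - k terms:
    older tokens are accumulated more often than newer ones.
    """
    return [n - k for k in range(n)]


def a2sf_scores(S: Matrix, alpha: float) -> List[float]:
    """Running A2SF score; alpha = 1.0 reduces to the plain accumulative score."""
    acc = [0.0] * len(S[0])
    for row in S:
        acc = [alpha * a + s for a, s in zip(acc, row)]
    return acc


def keep_top_k(scores: List[float], budget: int) -> List[int]:
    """Hypothetical pruning rule: indices (in original order) of the top scores."""
    ranked = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
    return sorted(ranked[:budget])


print(accumulation_counts(len(S)))               # -> [4, 3, 2, 1]
print(keep_top_k(a2sf_scores(S, 1.0), budget=2))  # -> [0, 1]
print(keep_top_k(a2sf_scores(S, 0.5), budget=2))  # -> [1, 2]
```

With $\alpha = 1$, the two oldest keys are kept largely because they were accumulated more often. With $\alpha = 0.5$, the selection shifts to keys 1 and 2, which receive more attention from the most recent queries.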
