Selective Attention Improves Transformer

Yaniv Leviathan, Matan Kalman, Yossi Matias

TL;DR

Selective Attention reduces attention to unneeded elements in the context through a parameter-free masking mechanism. A token can mask earlier tokens for all subsequent tokens through a selection matrix S, which is accumulated into a mask F. The selection scores reuse an existing attention head, so no parameters are added. They are constrained and then accumulated, which lowers how strongly future tokens attend to the masked elements. During inference, Context Pruning evicts masked tokens from the KV-cache. The authors ran extensive experiments with decoder-only transformers trained on C4. Across model sizes and context lengths, selective attention lowers perplexity and matches standard transformers that have about 2× more heads and parameters in their attention modules. Pruning also delivers substantial memory savings, up to 47× for a context size of 2,048. The approach is simple and parameter-free, and it also improves downstream tasks and T5 configurations. This suggests it is a practical default for improving both quality and inference efficiency in transformers.
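
The masking step described above can be sketched in a few lines of pure Python. This is a minimal sketch under stated assumptions, not the authors' code. It assumes the following:

* Head 0's causal attention logits double as the selection scores S.
* S passes through a ReLU, so only positive scores mask.
* The first (<BOS>) token and the token itself are never masked.
* A token's selection only affects strictly later tokens.

F accumulates S over earlier queries and is subtracted from every head's logits before the softmax. The function name `selective_attention_weights` is hypothetical.

```python
import math


def selective_attention_weights(logits):
    """Hypothetical sketch: causal softmax attention with selective masking.

    logits[h][i][j] is head h's raw logit for query i and key j; entries
    with j > i are ignored (causal). ASSUMPTION: head 0's logits are reused
    as the selection scores S, so no parameters are added.
    """
    n_heads, n = len(logits), len(logits[0])

    # S[i][j]: how strongly token i marks earlier token j as unneeded.
    S = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(1, i):  # j >= 1 skips <BOS>; j < i skips self
            S[i][j] = max(0.0, logits[0][i][j])  # ReLU: only positive scores mask

    # F[i][j] = sum of S[k][j] over k < i: masking decided by earlier tokens.
    F = [[0.0] * n for _ in range(n)]
    for i in range(1, n):
        for j in range(n):
            F[i][j] = F[i - 1][j] + S[i - 1][j]

    weights = []
    for h in range(n_heads):
        rows = []
        for i in range(n):
            # Subtract the accumulated mask from query i's causal logits.
            row = [logits[h][i][j] - F[i][j] for j in range(i + 1)]
            m = max(row)  # subtract the max for numerical stability
            exps = [math.exp(x - m) for x in row]
            total = sum(exps)
            rows.append([e / total for e in exps] + [0.0] * (n - i - 1))
        weights.append(rows)
    return weights, F


# Usage: one head, four tokens; token 2 strongly selects token 1.
logits = [[[0.0] * 4 for _ in range(4)]]
logits[0][2][1] = 5.0
weights, F = selective_attention_weights(logits)
print(F[3][1])  # 5.0: token 3 sees token 1's logit lowered by 5
```

Under the ReLU assumption F never decreases. Once a key has been masked strongly enough, it contributes almost nothing to any later softmax.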

Abstract

Unneeded elements in the attention's context degrade performance. We introduce Selective Attention, a simple parameter-free change to the standard attention mechanism which reduces attention to unneeded elements. Selective attention consistently improves language modeling and downstream task performance in a variety of model sizes and context lengths. For example, transformers trained with the language modeling objective on C4 with selective attention perform language modeling equivalently to standard transformers with ~2X more heads and parameters in their attention modules. Selective attention also allows decreasing the size of the attention's context buffer, leading to meaningful reductions in the memory and compute requirements during inference. For example, transformers trained on C4 with context sizes of 512, 1,024, and 2,048 need 16X, 25X, and 47X less memory for their attention module, respectively, when equipped with selective attention than those without it, at the same validation perplexity.
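
The reduced context buffer relies on Context Pruning. Keys with a large accumulated mask barely affect later softmaxes, so they can be dropped from the KV-cache. The sketch below is hypothetical. The class `PrunedKVCache`, the per-layer token `budget`, and the eviction rule are our assumptions for illustration, not the paper's exact procedure. The assumed rule drops the entry with the largest accumulated mask, evicts the oldest entry on ties, and never evicts the first <BOS> entry.

```python
from dataclasses import dataclass


@dataclass
class CacheEntry:
    position: int
    key: list
    value: list
    mask: float = 0.0  # accumulated F for this key, as seen by later queries


class PrunedKVCache:
    """Hypothetical per-layer KV-cache that keeps at most `budget` entries."""

    def __init__(self, budget):
        if budget < 2:
            raise ValueError("need room for <BOS> plus the newest token")
        self.budget = budget
        self.entries = []

    def append(self, position, key, value, selection):
        """Add a token and fold its selection scores into the cached masks.

        `selection` maps an earlier position to the non-negative score S the
        new token assigns it; positions absent from the dict get 0.
        """
        for entry in self.entries:
            entry.mask += selection.get(entry.position, 0.0)
        self.entries.append(CacheEntry(position, key, value))
        while len(self.entries) > self.budget:
            # Never evict <BOS> (position 0); otherwise drop the most-masked
            # entry, preferring the oldest one on ties.
            candidates = [e for e in self.entries if e.position != 0]
            victim = max(candidates, key=lambda e: (e.mask, -e.position))
            self.entries = [e for e in self.entries if e is not victim]

    def positions(self):
        return [e.position for e in self.entries]


cache = PrunedKVCache(budget=3)
for pos in range(3):
    cache.append(pos, [float(pos)], [float(pos)], {})
cache.append(3, [3.0], [3.0], {1: 4.0})  # token 3 masks token 1
print(cache.positions())  # [0, 2, 3]
```

If no token ever selects anything, this tie-break reduces to a sliding window over the non-<BOS> tokens. Selective attention's contribution is to choose which entries to drop based on the content.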

Paper Structure (29 sections, 2 equations, 12 figures, 8 tables)

Figures (12)

  • Figure 1: A visualization of the masking by selective attention (red strike-through) and attention strength (averaged across heads, blue highlight) for different tasks (see the section on motivating examples).
  • Figure 2: A sketch implementation of selective attention. The colored lines are the additions to standard attention.
  • Figure 3: (Left) The validation perplexity of a $d=12$ transformer, with (blue) and without (orange) selective attention, for varying context sizes. (Right) The validation perplexity of transformers of various sizes, with (blue) and without (orange) selective attention, for a context size of 512.
  • Figure 4: Perplexity of transformers of various sizes with and without selective attention. For the models without selective attention, we add attention heads along with their respective parameters (i.e., we increase the sizes of all projection matrices). Transformers with selective attention perform equivalently to those with standard attention with $\sim$2X as many heads and parameters.
  • Figure 5: Visualization of the $F$ matrix (greener is lower, i.e. less masking) for a $d=12$ transformer, on the example text given in the appendix (example details).
  • ...and 7 more figures