Selective Attention: Enhancing Transformer through Principled Context Control

Xuechen Zhang, Xiangyu Chang, Mingchen Li, Amit Roy-Chowdhury, Jiasi Chen, Samet Oymak

TL;DR

This work introduces the Selective Self-Attention (SSA) layer that augments the softmax nonlinearity with a principled temperature scaling strategy and demonstrates that this alleviates attention dilution, aids the optimization process, and enhances the model's ability to control softmax spikiness of individual queries.

Abstract

The attention mechanism within the transformer architecture enables the model to weigh and combine tokens based on their relevance to the query. While self-attention has enjoyed major success, it notably treats all queries $q$ in the same way by applying the mapping $V^\top\text{softmax}(Kq)$, where $V,K$ are the value and key embeddings respectively. In this work, we argue that this uniform treatment hinders the ability to control contextual sparsity and relevance. As a solution, we introduce the $\textit{Selective Self-Attention}$ (SSA) layer that augments the softmax nonlinearity with a principled temperature scaling strategy. By controlling temperature, SSA adapts the contextual sparsity of the attention map to the query embedding and its position in the context window. Through theory and experiments, we demonstrate that this alleviates attention dilution, aids the optimization process, and enhances the model's ability to control softmax spikiness of individual queries. We also incorporate temperature scaling for value embeddings and show that it boosts the model's ability to suppress irrelevant/noisy tokens. Notably, SSA is a lightweight method which introduces less than 0.5% new parameters through a weight-sharing strategy and can be fine-tuned on existing LLMs. Extensive empirical evaluations demonstrate that SSA-equipped models achieve a noticeable and consistent accuracy improvement on language modeling benchmarks.
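
The effect of temperature on the mapping $V^\top\text{softmax}(Kq)$ can be illustrated with a small sketch. This is not the authors' implementation: in SSA the temperature is learned as a function of the query embedding and its position, whereas here it is a hand-set scalar, and the names `softmax`, `attend`, and `entropy` as well as the toy keys, values, and query are assumptions chosen for illustration.

```python
import math

def softmax(scores, temperature=1.0):
    """Numerically stable softmax of scores / temperature."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values, temperature=1.0):
    """Single-query attention: V^T softmax(K q / temperature)."""
    # One score per key: the inner product between the key and the query.
    scores = [sum(k_i * q_i for k_i, q_i in zip(key, query)) for key in keys]
    weights = softmax(scores, temperature)
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return weights, output

def entropy(weights):
    """Shannon entropy of an attention map; lower means spikier."""
    return -sum(w * math.log(w) for w in weights if w > 0)

# Toy data (hypothetical, not taken from the paper).
keys = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
query = [1.0, 0.2]

# Temperature 1.0 recovers the vanilla mapping; a lower temperature
# concentrates the attention map on the best-matching key.
w_vanilla, out_vanilla = attend(query, keys, values, temperature=1.0)
w_sharp, out_sharp = attend(query, keys, values, temperature=0.1)

print("vanilla weights:", w_vanilla, "entropy:", entropy(w_vanilla))
print("sharp weights:  ", w_sharp, "entropy:", entropy(w_sharp))
```

With the temperature lowered, the same query and keys yield a lower-entropy attention map, which is the kind of per-query spikiness control the abstract refers to.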

Paper Structure

This paper contains 24 sections, 4 theorems, 18 equations, 4 figures, 8 tables.

Key Result

Lemma 1

Let $\bm{W}=\bm{W}_q\bm{W}_k^\top\in\mathbb{R}^{d\times d}$ be the combined query-key matrix. Let $\bm{a},\bm{b}\in\mathbb{R}^d$ be unit norm token embeddings associated with the specific and general token respectively. Suppose we wish to achieve specificities $\text{spec}_{\bm{W}}(\bm{a})\geq L_a$

Figures (4)

  • Figure 1: A quotation by Steve Jobs. We highlight tokens according to their temperatures learned by the SSA layer. Darker colors correspond to lower temperatures and receive a sparser attention map.
  • Figure 2: The operator norm of $\bm{W}$ with and without query-temperature scaling, scaled by $\times 10^{3}$. The figure depicts the distribution across 1000 tokens. The dashed lines mark the average norms. Notably, the norm of the vanilla attention layer is approximately three times larger than that of SSA (compare the dashed red line to the green line). Furthermore, the vanilla attention layer is less spiky than SSA: its spikiness score is 0.39 versus 0.26 for SSA, where a lower score indicates higher spikiness.
  • Figure 3: We compare 1-layer SSA and 1-layer attention when solving next-token prediction on a small vocabulary of size 8. (a) is the graph associated with the token transition dynamics. (b) is the pairwise token transition matrix of this vocabulary. Each row of ${\bm{P}}_\star$ represents an attention map where a particular token is the query and all tokens in the vocabulary serve as keys (see the section on query benefits for details). The transition matrix $\hat{{\bm{P}}}$ estimated by SSA in (c) is sharper and more closely resembles the optimal ${\bm{P}}_\star$. SSA achieves a smaller cross-entropy loss than vanilla attention, 0.009 vs 0.0126. The $\ell_1$ approximation error of the attention map of SSA is also smaller than that of vanilla attention, 0.358 vs 0.543.
  • Figure 4: Comparison of training curves. SSA provides reasonable benefits in terms of training speedup.
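
The two metrics quoted in the Figure 3 caption can be sketched as follows. The exact definitions used in the paper are not given here, so this assumes the $\ell_1$ error is the sum of absolute entrywise differences between $\hat{{\bm{P}}}$ and ${\bm{P}}_\star$, and the cross-entropy is averaged over rows with ${\bm{P}}_\star$ as the target. The function names and the 3-token matrices are hypothetical and do not reproduce the paper's numbers.

```python
import math

def l1_error(p_hat, p_star):
    """Sum of absolute entrywise differences between two matrices."""
    return sum(
        abs(h - s)
        for row_hat, row_star in zip(p_hat, p_star)
        for h, s in zip(row_hat, row_star)
    )

def cross_entropy(p_hat, p_star):
    """Row-averaged cross-entropy of p_hat against the target p_star."""
    total = 0.0
    for row_hat, row_star in zip(p_hat, p_star):
        # Entries with zero target probability contribute nothing.
        total -= sum(s * math.log(h) for h, s in zip(row_hat, row_star) if s > 0)
    return total / len(p_star)

# Hypothetical target transition matrix: one sharp row, one mixed row,
# and one uniform row (each row sums to 1).
p_star = [
    [0.9, 0.1, 0.0],
    [0.0, 0.5, 0.5],
    [1 / 3, 1 / 3, 1 / 3],
]

# Two hypothetical estimates: one that can match sharp rows, one that cannot.
p_sharp = [
    [0.85, 0.10, 0.05],
    [0.05, 0.50, 0.45],
    [0.30, 0.35, 0.35],
]
p_diffuse = [
    [0.5, 0.3, 0.2],
    [0.2, 0.4, 0.4],
    [0.30, 0.35, 0.35],
]

print("l1 error:      sharp", l1_error(p_sharp, p_star),
      "diffuse", l1_error(p_diffuse, p_star))
print("cross-entropy: sharp", cross_entropy(p_sharp, p_star),
      "diffuse", cross_entropy(p_diffuse, p_star))
```

In this toy setting the estimate that can sharpen individual rows scores better on both metrics, mirroring the direction of the comparison reported in the caption.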

Theorems & Definitions (10)

  • Definition 1: Selective Self-Attention (SSA)
  • Lemma 1
  • Proposition 1
  • Proposition 2
  • proof
  • proof
  • Lemma 2
  • proof
  • Claim 1: Benefits on attention map
  • Claim 2: Benefits on prediction