StableMask: Refining Causal Masking in Decoder-only Transformer

Qingyu Yin, Xuzheng He, Xiang Zhuang, Yu Zhao, Jianhua Yao, Xiaoyu Shen, Qiang Zhang

TL;DR

StableMask is a parameter-free refinement of the causal mask in decoder-only Transformers. It injects pseudo-attention values into the masked region and applies a progressively decreasing mask ratio. Together, these two mechanisms balance attention distributions, which alleviates disproportionate attention to specific tokens, and allow the model to encode absolute positional information. This addresses two limitations of softmax attention combined with relative position encoding (RPE). The approach is supported by theoretical results and by empirical gains on models from 71M to 1.4B parameters, improves extrapolation while working with existing positional encodings, and integrates with hardware-accelerated attention implementations such as FlashAttention. In practice, StableMask improves perplexity and downstream task performance, and an inference variant (SM-I) remains cache-friendly and compatible with existing attention optimizations.

Abstract

The decoder-only Transformer architecture with causal masking and relative position encoding (RPE) has become the de facto choice in language modeling. Despite its exceptional performance across various tasks, we have identified two limitations: First, it requires all attention scores to be non-zero and sum up to 1, even if the current embedding has sufficient self-contained information. This compels the model to assign disproportional excessive attention to specific tokens. Second, RPE-based Transformers are not universal approximators due to their limited capacity at encoding absolute positional information, which limits their application in position-critical tasks. In this work, we propose StableMask: a parameter-free method to address both limitations by refining the causal mask. It introduces pseudo-attention values to balance attention distributions and encodes absolute positional information via a progressively decreasing mask ratio. StableMask's effectiveness is validated both theoretically and empirically, showing significant enhancements in language models with parameter sizes ranging from 71M to 1.4B across diverse datasets and encoding methods. We further show that it naturally supports (1) efficient extrapolation without special tricks such as StreamingLLM and (2) easy integration with existing attention optimization techniques.

Paper Structure (33 sections, 4 theorems, 30 equations, 6 figures, 6 tables)

Key Result

Theorem 4.1

Let $X = [\boldsymbol{x}_1, \cdots, \boldsymbol{x}_n]_n$ be an input sequence of length $n$ to the StableMask model $f^{\text{(SM)}}_T$. Then, the first layer of $f^{\text{(SM)}}_T$ can recover absolute positions $[1,2,\dots,n]$ in the hidden state $\Omega^{(1)}$. That is, there exist $W_Q$, $W_K$,

Figures (6)

  • Figure 1: (a) Visual comparison of attention heads with and without StableMask on the OpenLLaMA 1.4B model. (b) The attention allocation to various types of tokens (excluding the initial token) at two different positions and the trend of attention allocation to the initial token over positions, averaged over heads. Blue: The original Transformer exhibits a clear disproportional attention issue. Green: StableMask effectively rectifies the proportion of attention allocation. (c) Experimental results on absolute position encoding. Blue: RPE alone cannot encode absolute position. Green: StableMask resolves this limitation.
  • Figure 2: (a) Illustration of the StableMask mechanism. (b) StableMask integrates with the softmax operation, replacing the traditional causal mask. (c) The attention score matrix is first cleared of attention values in the upper triangular part using the $C$ matrix, then pseudo-attention scores are added using the $P$ matrix followed by the softmax computation. (d) After the softmax operation, the remaining attention probabilities in the upper triangular part are cleared using $C$ to ensure the causal decoding property. (e) The $C$ matrix has zeros in the upper triangular part and ones in the lower triangular part, while the $P$ matrix has linear decay in the upper triangular part and zeros in the lower triangular part. $\gamma$ is a hyperparameter. (f) StableMask for inference. An input sequence needs a suffix.
  • Figure 3: StableMask for Inference. The original StableMask implementation needs to recompute the softmax result for the attention score matrix because additional mask values are added. StableMask for Inference introduces a factor $\tau$ to fix the situation to be in the form of the maximum training length.
  • Figure 4: (a-c) Scaling curves of models from 160M to 1.4B parameters across different positional encodings. (d) Extrapolation results (with window attention). StableMask consistently improves model performance while enabling effective extrapolation.
  • Figure 5: Inference latency test on OpenLLaMA 1.4B. Our proposed StableMask adapted for fast inference (SM-I) significantly reduces the running latency.
  • ...and 1 more figure
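
The masking steps described in the Figure 2 caption can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `stable_mask_attention` is hypothetical, and the exact form of the linear decay in $P$ (row $i$ receiving $-i\gamma$ above the diagonal) is an assumption, since the caption only states that $P$ decays linearly in the upper triangular part.

```python
import numpy as np

def stable_mask_attention(scores, gamma=1.0):
    """Apply a StableMask-style mask to an (n, n) matrix of raw attention scores.

    Follows the order in the Figure 2 caption:
    clear the upper triangle with C, add pseudo-attention scores P,
    take the softmax, then clear the upper triangle again with C.
    """
    n = scores.shape[0]
    # C: ones on and below the diagonal, zeros above it.
    C = np.tril(np.ones((n, n)))
    # P: zeros on and below the diagonal, linear decay above it.
    # ASSUMPTION: row i (1-based) gets the value -i * gamma above the diagonal.
    rows = np.arange(1, n + 1)[:, None]
    P = -gamma * rows * (1.0 - C)
    logits = scores * C + P
    # Numerically stable softmax over each row.
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    # Remove the probability mass assigned to pseudo-attention slots,
    # so no token attends to a future position.
    return probs * C

# Small demo on random scores (hypothetical input).
rng = np.random.default_rng(0)
attn = stable_mask_attention(rng.normal(size=(5, 5)), gamma=0.5)
row_sums = attn.sum(axis=-1)
```

In this sketch the rows of `attn` no longer have to sum to 1: early rows have many pseudo-attention slots and can leave part of their probability mass unassigned, while the last row has none and still sums to 1. The row-dependent amount of available mass is one way to read the "progressively decreasing mask ratio" mentioned in the abstract.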

Theorems & Definitions (11)

  • Theorem 4.1
  • Definition 1.1
  • Definition 1.2
  • Definition 1.3
  • Proposition 1.4
  • proof
  • Definition 1.5
  • Theorem 2.1
  • proof
  • Theorem 3.1
  • ...and 1 more