Segment-Based Attention Masking for GPTs

Shahar Katz, Liran Ringel, Yaniv Romano, Lior Wolf

TL;DR

The paper addresses a limitation of standard autoregressive GPTs: they cannot use future context during the prefill phase. It proposes Masked Attention by Segment (MAS), a lightweight fine-tuning approach that lifts the causal mask within predefined prompt segments during prefill and reverts to causal masking during generation, enabling bidirectional context within each segment without training a new architecture. Empirical results across multiple base models and eight commonsense-reasoning tasks show consistent accuracy gains, with improvements that appear early and are sustained, and ablations highlight the importance of fine-tuning the attention components. The approach offers practical benefits for chat-based systems, such as segment-wise attention and potential system-prompt caching. The authors note limitations with very long prompts and the need for task-specific tuning.

Abstract

Modern Language Models (LMs) owe much of their success to masked causal attention, the backbone of Generative Pre-Trained Transformer (GPT) models. Although GPTs can process the entire user prompt at once, the causal masking is applied to all input tokens step-by-step, mimicking the generation process. This imposes an unnecessary constraint during the initial "prefill" phase when the model processes the input prompt and generates the internal representations before producing any output tokens. In this work, attention is masked based on the known block structure at the prefill phase, followed by the conventional token-by-token autoregressive process after that. For example, in a typical chat prompt, the system prompt is treated as one block, and the user prompt as the next one. Each of these is treated as a unit for the purpose of masking, such that the first tokens in each block can access the subsequent tokens in a non-causal manner. Then, the model answer is generated in the conventional causal manner. This Segment-by-Segment scheme entails no additional computational overhead. When integrating it into models such as Llama and Qwen, state-of-the-art performance is consistently achieved.
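The masking rule described in the abstract can be illustrated with a small sketch. This is not the authors' implementation: the function name `mas_mask`, the `segment_ids` encoding, and the convention that generated tokens carry the id `-1` are assumptions made only for illustration. A token may attend to every earlier token (the usual causal rule) and, if it belongs to a prompt block, also to the later tokens of that same block.

```python
from typing import List


def mas_mask(segment_ids: List[int]) -> List[List[bool]]:
    """Build a segment-based attention mask (illustrative sketch only).

    segment_ids[k] is the block index of token k. Prompt blocks use
    non-negative ids (e.g. 0 = system prompt, 1 = user prompt); generated
    tokens use -1. Both conventions are hypothetical, chosen for this example.

    Returns mask[i][j] == True when token i may attend to token j.
    """
    n = len(segment_ids)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            causal = j <= i  # standard rule: attend to self and the past
            same_prompt_block = (
                segment_ids[i] >= 0 and segment_ids[i] == segment_ids[j]
            )  # extra rule: see the whole block the token belongs to
            mask[i][j] = causal or same_prompt_block
    return mask


# 3 system-prompt tokens, 2 user-prompt tokens, 2 generated tokens.
segment_ids = [0, 0, 0, 1, 1, -1, -1]
mask = mas_mask(segment_ids)
for row in mask:
    print("".join("1" if allowed else "0" for allowed in row))
# 1110000
# 1110000
# 1110000
# 1111100
# 1111100
# 1111110
# 1111111
```

In this sketch the rows for the system-prompt tokens never reference the user block, which is consistent with the system-prompt caching mentioned in the TL;DR: their representations do not depend on what follows. The last two rows are purely causal, matching the conventional token-by-token generation phase.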

Paper Structure

This paper contains 17 sections, 10 equations, 7 figures, 4 tables.

Figures (7)

  • Figure 1: Causal and MAS attention. The plot shows binary values, where the y-axis represents the index of the current token, and the x-axis represents the set of indices of tokens it can attend to. MAS is inspired by the observation that input prompts are provided to the model as a whole, so they can be masked together in blocks, allowing access to future tokens within the same block of the prompt.
  • Figure 3: Average accuracy on the Commonsense Reasoning tasks during fine-tuning of Llama-3.2 (1B or 3B) with either conventional (causal) attention or MAS.
  • Figure 4: Average accuracy on the Commonsense Reasoning tasks when training different matrices of Llama-3.2.
  • Figure 5: The attention maps of two instances of Llama-3.2-1B fine-tuned for commonsense reasoning tasks, one with standard causal masking and the other with MAS, reveal distinct patterns: (i) L0H25 (layer 0, head 25) exhibits N-gram Patterns, indicating attention focused on sequences of consecutive tokens. (ii) L9H13 demonstrates Block-Specific Patterns, concentrated within defined blocks of the prompts. (iii) L10H23 showcases a Forward-Looking behavior, attending precisely to the next tokens within the same block. (iv) L1H20 serves as an example of Preserved Patterns.
  • Figure 6: Average accuracy on 6 of the Commonsense Reasoning tasks across random seeds and learning rates.
  • ...and 2 more figures