Mask-Enhanced Autoregressive Prediction: Pay Less Attention to Learn More

Xialie Zhuang, Zhikai Jia, Jianjin Li, Zhenyu Zhang, Li Shen, Zheng Cao, Shiwei Liu

TL;DR

Mask-Enhanced Autoregressive Prediction (MEAP) adds a lightweight masked-token signal to standard decoder-only next-token prediction. This brings the benefits of masked language modeling without an encoder-decoder architecture and without extra pre-training or inference cost. During pre-training, MEAP randomly masks a small fraction of input tokens. During fine-tuning, it duplicates each training sample and masks the copy. Together, these steps improve information retrieval and long-context reasoning while preserving language modeling capability and improving data efficiency. Empirically, MEAP yields substantial gains on Needle-in-a-Haystack, multi-document QA, and other long-context tasks. It also reduces contextual hallucinations and generalizes across model families, with notable efficiency advantages. The authors attribute these gains to MEAP producing more distinguishable attention scores that concentrate on non-masked tokens, which sharpens task-relevant signals for the model.

Abstract

Large Language Models (LLMs) have been found to struggle with accurately retrieving key information. To address this, we propose Mask-Enhanced Autoregressive Prediction (MEAP), a simple yet effective training paradigm that seamlessly integrates Masked Language Modeling (MLM) into Next-Token Prediction (NTP) to enhance the latter's in-context retrieval capabilities. Specifically, MEAP first randomly masks a small fraction of input tokens and then directly performs standard autoregressive next-token prediction using a decoder-only Transformer. MEAP eliminates the need for bidirectional attention or encoder-decoder architectures for MLM, incurring no additional computational overhead during pre-training or inference. Extensive experiments demonstrate that MEAP substantially outperforms NTP on key information retrieval and long-context reasoning tasks, while performing on par or better on commonsense reasoning tasks. The benefits of MEAP also extend to supervised fine-tuning, where it shows remarkable advantages in lost-in-the-middle scenarios, outperforming NTP by 11.77 percentage points. Our analysis indicates that MEAP's effectiveness arises from its ability to promote more distinguishable attention scores by concentrating on a reduced set of non-masked tokens. This mechanism improves the model's focus on task-relevant signals while mitigating the influence of peripheral context. These findings position MEAP as a promising training paradigm for large language models.
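The masking described above changes only the training data, not the model. The Python sketch below illustrates the idea for both pre-training and fine-tuning, and it rests on several assumptions. `MASK_ID`, the helper names, and the default `mask_ratio=0.15` are hypothetical; the paper text here says only "a small fraction". Using the original (unmasked) tokens as pre-training targets is also an assumption. The fine-tuning helper follows the Figure 2 description (duplicate the sample, then mask the copy) and assumes the original is placed before the masked copy.

```python
import random
from typing import List, Tuple

MASK_ID = 0  # hypothetical id of a [MASK] token; real tokenizers differ


def random_mask(tokens: List[int], mask_ratio: float, rng: random.Random) -> List[int]:
    """Return a copy of `tokens` with round(mask_ratio * len) distinct positions set to MASK_ID."""
    n_mask = int(round(mask_ratio * len(tokens)))
    positions = set(rng.sample(range(len(tokens)), n_mask))
    return [MASK_ID if i in positions else t for i, t in enumerate(tokens)]


def pretrain_example(
    tokens: List[int], mask_ratio: float = 0.15, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Pre-training sketch: mask the input, then do ordinary next-token prediction.

    Assumption: targets are the ORIGINAL tokens shifted left by one, so an
    unchanged decoder-only model is trained with plain causal cross-entropy.
    """
    rng = random.Random(seed)
    masked = random_mask(tokens, mask_ratio, rng)
    # Standard NTP shift: input position i predicts original token i + 1.
    return masked[:-1], tokens[1:]


def finetune_sequence(tokens: List[int], mask_ratio: float = 0.15, seed: int = 0) -> List[int]:
    """Fine-tuning sketch: concatenate the sample with a randomly masked copy of itself.

    Assumption: original first, masked copy second. Which positions contribute
    to the loss is left to the training loop and is not specified here.
    """
    rng = random.Random(seed)
    return list(tokens) + random_mask(tokens, mask_ratio, rng)


if __name__ == "__main__":
    sample = list(range(1, 21))  # 20 toy token ids, none equal to MASK_ID
    inputs, targets = pretrain_example(sample)
    print(inputs)
    print(targets)
    print(finetune_sequence(sample))
```

Both helpers only emit token-id lists, so the model and the causal loss stay exactly as in standard NTP. This is consistent with the claim that MEAP needs no bidirectional attention or encoder-decoder stack. In this sketch, the duplicated fine-tuning sequence is twice as long as the original sample.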

Paper Structure

This paper contains 30 sections, 4 equations, 7 figures, 19 tables.

Figures (7)

  • Figure 1: Overview of next token prediction, masked language modeling, and our MEAP.
  • Figure 2: Training frameworks of MEAP: Left (Pre-training): A certain portion of input tokens is randomly masked, followed by standard next-token prediction (NTP). Right (Fine-tuning): Training samples are duplicated, and the random masking strategy is applied to the copied sequences. Standard NTP is then performed on the modified input for fine-tuning.
  • Figure 3: Performance comparison between NTP and MEAP on Needle In A Haystack. Scores are computed using ROUGE-1, measuring unigram overlap between model responses and expected answers.
  • Figure 4: Long-context reasoning performance comparison between MEAP and NTP on the Multi-Needle Reasoning Task (M-RS) across different context lengths.
  • Figure 5: Comparison of fine-tuning efficiency between MEAP and NTP. 'MEAP-n' denotes MEAP trained for n epochs.
  • ...and 2 more figures
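Figure 3 scores Needle-in-a-Haystack responses with ROUGE-1, that is, unigram overlap with the expected answer. The minimal sketch below assumes lowercase whitespace tokenization with no stemming. The figure caption does not say whether precision, recall, or F1 is reported, so all three are returned. Reference implementations such as Google's `rouge-score` package tokenize differently, so this sketch is not guaranteed to reproduce the reported numbers.

```python
from collections import Counter
from typing import Dict


def rouge1(response: str, reference: str) -> Dict[str, float]:
    """Unigram-overlap ROUGE-1 under a simplifying assumption: lowercase, split on whitespace."""
    resp = response.lower().split()
    ref = reference.lower().split()
    # Clipped overlap: each reference unigram can be matched at most as often as it occurs.
    overlap = sum((Counter(resp) & Counter(ref)).values())
    precision = overlap / len(resp) if resp else 0.0
    recall = overlap / len(ref) if ref else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


if __name__ == "__main__":
    scores = rouge1("the needle is in the haystack", "the needle is hidden in the haystack")
    print(scores)  # precision 1.0, recall 6/7 ≈ 0.857, F1 12/13 ≈ 0.923
```

A response that omits one reference word keeps perfect precision but loses recall. This is why the choice of variant matters when comparing scores across papers.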