Mask-Enhanced Autoregressive Prediction: Pay Less Attention to Learn More
Xialie Zhuang, Zhikai Jia, Jianjin Li, Zhenyu Zhang, Li Shen, Zheng Cao, Shiwei Liu
TL;DR
Mask-Enhanced Autoregressive Prediction (MEAP) adds a lightweight masked-token signal to standard decoder-only next-token prediction. This brings the benefits of masked language modeling without an encoder-decoder architecture and without extra pretraining or inference cost. During pretraining, MEAP randomly masks a small fraction of input tokens. During fine-tuning, it duplicates each training sequence and masks the copy. This improves information retrieval and long-context reasoning while preserving language modeling capability and improving data efficiency. Empirically, MEAP yields substantial gains on Needle-in-a-Haystack, multi-document QA, and other long-context tasks, reduces contextual hallucinations, generalizes across model families, and offers notable efficiency advantages. The authors attribute these gains to MEAP producing more distinguishable attention: by concentrating on non-masked tokens, the model sharpens task-relevant signals.
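The fine-tuning step (duplicate a sequence, mask only the copy) can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' implementation. The names `meap_duplicate_and_mask` and `MASK_ID`, the default `mask_ratio` of 0.1, and the choice to append the masked copy directly after the original are all hypothetical. How the loss is applied to the concatenated sequence is not specified here.

```python
import random

MASK_ID = 0  # hypothetical id of a dedicated [MASK] token


def meap_duplicate_and_mask(tokens, mask_ratio=0.1, mask_id=MASK_ID, seed=None):
    """Return (original + masked copy, positions masked in the copy).

    Assumptions (not from the paper): the copy is appended directly after
    the original, and masking touches only the copy, never the original.
    mask_ratio is expected to lie in [0, 1].
    """
    rng = random.Random(seed)
    n_mask = int(round(mask_ratio * len(tokens)))
    # Choose which positions of the copy to replace with the mask token.
    masked = set(rng.sample(range(len(tokens)), n_mask))
    copy = [mask_id if i in masked else t for i, t in enumerate(tokens)]
    # Offset copy positions by len(tokens) so they index the full sequence.
    positions = sorted(len(tokens) + i for i in masked)
    return list(tokens) + copy, positions


seq, masked_positions = meap_duplicate_and_mask(list(range(1, 21)), seed=0)
```

With 20 tokens and `mask_ratio=0.1`, the result has 40 tokens: the untouched original followed by a copy in which two positions are replaced by `MASK_ID`.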
Abstract
Large Language Models (LLMs) have been found to struggle with accurately retrieving key information. To address this, we propose Mask-Enhanced Autoregressive Prediction (MEAP), a simple yet effective training paradigm that seamlessly integrates Masked Language Modeling (MLM) into Next-Token Prediction (NTP) to enhance the latter's in-context retrieval capabilities. Specifically, MEAP first randomly masks a small fraction of input tokens and then directly performs standard autoregressive next-token prediction with a decoder-only Transformer. MEAP eliminates the need for bidirectional attention or encoder-decoder architectures for MLM and incurs no additional computational overhead during pre-training or inference. Extensive experiments demonstrate that MEAP substantially outperforms NTP on key information retrieval and long-context reasoning tasks, while performing on par or better on commonsense reasoning tasks. The benefits of MEAP also extend to supervised fine-tuning, where it shows remarkable advantages in lost-in-the-middle scenarios, outperforming NTP by 11.77 percentage points. Our analysis indicates that MEAP's effectiveness arises from its ability to promote more distinguishable attention scores by concentrating on a reduced set of non-masked tokens. This mechanism improves the model's focus on task-relevant signals while mitigating the influence of peripheral context. These findings position MEAP as a promising training paradigm for large language models.
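The pretraining recipe ("randomly mask a small fraction of input tokens, then perform standard next-token prediction") amounts to a change in how each training example is built. The sketch below shows that input construction under an explicit assumption: masked positions are replaced in the inputs only, while the targets remain the original tokens shifted by one, exactly as in plain NTP. The function name `meap_pretrain_example`, the `MASK_ID` value, and the default `mask_ratio` of 0.15 are hypothetical and chosen for illustration.

```python
import random

MASK_ID = 0  # hypothetical id of a dedicated [MASK] token


def meap_pretrain_example(tokens, mask_ratio=0.15, mask_id=MASK_ID, seed=None):
    """Build (inputs, targets, masked_positions) for one sequence.

    Assumption: only the inputs are masked; targets are the original
    tokens shifted left by one, so the model and loss are unchanged NTP.
    """
    rng = random.Random(seed)
    n_inputs = max(len(tokens) - 1, 0)
    n_mask = int(round(mask_ratio * n_inputs))
    # Sample input positions to hide; the last token is only ever a target.
    masked = set(rng.sample(range(n_inputs), n_mask))
    inputs = [mask_id if i in masked else t for i, t in enumerate(tokens[:-1])]
    targets = list(tokens[1:])
    return inputs, targets, sorted(masked)


inputs, targets, masked = meap_pretrain_example(list(range(5, 16)), mask_ratio=0.2, seed=0)
```

Because only the input side changes, the same decoder-only model, causal attention mask, and cross-entropy loss used for NTP apply unchanged. This is consistent with the claim of no extra pretraining or inference cost.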
