
FastMem: Fast Memorization of Prompt Improves Context Awareness of Large Language Models

Junyi Zhu, Shuochen Liu, Yu Yu, Bo Tang, Yibo Yan, Zhiyu Li, Feiyu Xiong, Tong Xu, Matthew B. Blaschko

TL;DR

This work tackles context-awareness failures in instruction-tuned LLMs, where contextual information conflicts with parametric knowledge. It introduces FastMem, an inference-time memorization approach that updates only the last FFN module to memorize the prompt. Its best-performing objective is $L_{NTP}(x; \boldsymbol{\theta}) + \lambda L_{KL}(x; \boldsymbol{\theta})$, and memorization is stabilized via memorization control tokens. The method can be combined with decoding strategies such as Contrastive Decoding (CD) and DoLa to further boost robustness. Empirically, FastMem yields substantial gains across QA and summarization tasks, e.g., improving Llama 3-8B-Instruct on NQ-SWAP from 59.1% to 71.6% and reducing Qwen 1.5-4B-Chat output-structure failures from 34.9% to 25.5%, with minimal memory and latency overhead.
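To make the combined objective concrete, here is a minimal sketch in plain Python. It assumes that $L_{NTP}$ is the mean negative log-likelihood of the prompt tokens under the updated model and that $L_{KL}$ is the mean KL divergence from the original (frozen) model's next-token distribution to the updated model's distribution at each prompt position. The KL direction, the function names and the default `lam` value are illustrative assumptions, not details taken from the paper.

```python
import math
from typing import List, Sequence

def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vocabulary-sized logit vector."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ntp_loss(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean next-token NLL; logits[i] is the prediction for prompt token targets[i]."""
    nll = 0.0
    for row, t in zip(logits, targets):
        nll -= math.log(softmax(row)[t])
    return nll / len(targets)

def kl_loss(ref_logits: Sequence[Sequence[float]], logits: Sequence[Sequence[float]]) -> float:
    """Mean KL(p_ref || p_theta) over positions (the direction is an assumption)."""
    total = 0.0
    for r_row, row in zip(ref_logits, logits):
        p_ref, p = softmax(r_row), softmax(row)
        total += sum(pr * (math.log(pr) - math.log(q)) for pr, q in zip(p_ref, p) if pr > 0)
    return total / len(logits)

def fastmem_objective(logits, ref_logits, targets, lam: float = 0.5) -> float:
    """Hypothetical helper: L_NTP + lam * L_KL (lam = 0.5 is an arbitrary placeholder)."""
    return ntp_loss(logits, targets) + lam * kl_loss(ref_logits, logits)

# Toy example: 3 prompt positions, vocabulary of 4 tokens.
logits = [[2.0, 0.5, -1.0, 0.0], [0.1, 1.5, 0.3, -0.2], [0.0, 0.0, 2.5, 0.4]]
ref_logits = [[1.0, 0.5, -1.0, 0.0], [0.1, 1.0, 0.3, -0.2], [0.0, 0.0, 1.5, 0.4]]
print(fastmem_objective(logits, ref_logits, targets=[0, 1, 2], lam=0.5))
```

The KL term penalizes the updated model for drifting away from the original model's predictions, so a larger $\lambda$ trades memorization strength for stability.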

Abstract

Large language models (LLMs) excel in generating coherent text, but they often struggle with context awareness, leading to inaccuracies in tasks requiring faithful adherence to provided information. We introduce FastMem, a novel method designed to enhance instruction fine-tuned LLMs' context awareness through fast memorization of the prompt. FastMem maximizes the likelihood of the prompt before inference by updating only the last Feed-Forward Network (FFN) module. This targeted approach ensures efficient optimization without overfitting, significantly improving the model's ability to comprehend and accurately follow the context. Our experiments demonstrate substantial gains in reading comprehension, text summarization and adherence to output structures. For instance, FastMem improves the accuracy of Llama 3-8B-Inst on the NQ-SWAP dataset from 59.1% to 71.6%, and reduces the output structure failure rate of Qwen 1.5-4B-Chat from 34.9% to 25.5%. Extensive experimental results highlight FastMem's potential to offer a robust solution to enhance the reliability and accuracy of LLMs in various applications. Our code is available at: https://github.com/IAAR-Shanghai/FastMem
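Memorizing the prompt by updating a single module can be illustrated with a toy stand-in. Fixed feature vectors play the role of the frozen network, and one trainable linear map `W` plays the role of the updated module. This is only a sketch under those assumptions: in FastMem the trainable part is the last FFN module of a real LLM, and every name here (`memorize`, `feats`, the token ids) is hypothetical. Full-batch gradient descent on the prompt's next-token loss lowers the prompt perplexity, computed as exp(mean NLL).

```python
import math
import random
from typing import List

Matrix = List[List[float]]

def softmax(z: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def logits_for(W: Matrix, h: List[float]) -> List[float]:
    """Toy trainable module: one linear map applied to frozen features h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def mean_nll(W: Matrix, feats: Matrix, targets: List[int]) -> float:
    """Mean next-token negative log-likelihood of the prompt tokens."""
    return -sum(math.log(softmax(logits_for(W, h))[t]) for h, t in zip(feats, targets)) / len(targets)

def perplexity(W: Matrix, feats: Matrix, targets: List[int]) -> float:
    """Prompt perplexity = exp(mean NLL)."""
    return math.exp(mean_nll(W, feats, targets))

def memorize(W: Matrix, feats: Matrix, targets: List[int], steps: int = 200, lr: float = 0.5) -> Matrix:
    """Gradient descent on the prompt NTP loss; only W changes, feats stay frozen."""
    V, D, N = len(W), len(W[0]), len(targets)
    for _ in range(steps):
        grad = [[0.0] * D for _ in range(V)]
        for h, t in zip(feats, targets):
            p = softmax(logits_for(W, h))
            for v in range(V):
                # d(NLL)/d(logit_v) = p_v - 1[v == t], averaged over positions
                g = (p[v] - (1.0 if v == t else 0.0)) / N
                for d in range(D):
                    grad[v][d] += g * h[d]
        for v in range(V):
            for d in range(D):
                W[v][d] -= lr * grad[v][d]
    return W

def unit(v: List[float]) -> List[float]:
    """Normalize to unit length so a fixed learning rate is stable."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

random.seed(0)
V, D = 8, 6
feats = [unit([random.gauss(0.0, 1.0) for _ in range(D)]) for _ in range(5)]  # frozen features
targets = [3, 1, 4, 1, 5]  # hypothetical prompt token ids
W = [[random.gauss(0.0, 0.1) for _ in range(D)] for _ in range(V)]
before = perplexity(W, feats, targets)
memorize(W, feats, targets)
after = perplexity(W, feats, targets)
print(f"prompt perplexity: {before:.2f} -> {after:.2f}")
```

Because the features upstream of `W` never change, each update touches only one small set of parameters. The sketch uses this to mimic the efficiency argument for restricting FastMem's updates to the last FFN module.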

Paper Structure (46 sections, 6 equations, 16 figures, 7 tables)

Figures (16)

  • Figure 1: An illustrative example where an LLM cannot follow the reference because it conflicts with the LLM's prior knowledge (a). In such cases, the LLM often exhibits high perplexity on the conflicting information. Our proposed FastMem addresses this issue by enabling the model to memorize the reference text (thereby reducing perplexity) before conducting inference (b).
  • Figure 2: Overall frameworks of our proposed FastMem and its integration with decoding strategies.
  • Figure 3: Model performance and perplexity on NQ and NQ-SWAP datasets.
  • Figure 4: Q&A template for Llama 3. Control tokens are highlighted in blue. Keys are colored in gray and are replaced by corresponding content from the dataset.
  • Figure 5: Q&A memorization template for Qwen 1.5-4B-Chat. Instructions are embedded within the memorization text to improve the output structure. For the inference template, refer to the control-token template figure. Control tokens are highlighted in blue, while keys are shown in gray and are replaced with corresponding content from the dataset.
  • ...and 11 more figures
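Figure 2 refers to integrating FastMem with decoding strategies. As a rough illustration, the sketch below implements one greedy step of generic contrastive decoding. Tokens are scored by the expert log-probability minus the amateur log-probability, and only tokens whose expert probability is at least `alpha` times the maximum are considered. Which two distributions FastMem actually contrasts is not specified in this summary. Treating the memorized model as the expert is an assumption, and `contrastive_pick` and `alpha` are hypothetical names.

```python
import math
from typing import List, Sequence

def log_softmax(z: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]

def contrastive_pick(expert_logits: Sequence[float], amateur_logits: Sequence[float],
                     alpha: float = 0.1) -> int:
    """One greedy contrastive-decoding step; returns the chosen token id."""
    lp_exp = log_softmax(expert_logits)
    lp_ama = log_softmax(amateur_logits)
    # Plausibility constraint: only tokens with p_exp >= alpha * max(p_exp) are candidates.
    cutoff = max(lp_exp) + math.log(alpha)
    best_tok, best_score = -1, -math.inf
    for tok, (le, la) in enumerate(zip(lp_exp, lp_ama)):
        if le < cutoff:
            continue
        score = le - la  # expert log-prob minus amateur log-prob
        if score > best_score:
            best_tok, best_score = tok, score
    return best_tok

# Greedy decoding on the expert alone would pick token 0; the contrast favors token 1,
# which the expert rates nearly as high but the amateur rates much lower.
print(contrastive_pick([2.0, 1.8, -3.0], [3.0, 0.0, 0.0]))  # → 1
```

The plausibility cutoff matters: without it, a token that both models consider very unlikely could still win on the log-ratio alone.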