MLP Memory: A Retriever-Pretrained Memory for Large Language Models

Rubin Wei, Jiaqi Cao, Jiarui Wang, Jushi Kai, Qipeng Guo, Bowen Zhou, Zhouhan Lin

TL;DR

MLP Memory is a lightweight parametric memory that internalizes retrieval patterns: an all-MLP module is pretrained to imitate a $k$NN retriever over the entire pretraining corpus, and its output distribution is interpolated with the base LLM's at inference. The approach aims to combine the factual grounding of retrieval with the efficiency of parametric models, avoiding the latency and shallow integration of RAG while mitigating the catastrophic forgetting seen with CPT and LoRA. Empirically, MLP Memory improves the fitted scaling exponent by 17.5% on WikiText-103 and 24.1% on the Web dataset, yields a 12.3% relative improvement on five QA benchmarks and a 5.2-point absolute gain across nine general NLP tasks, and reduces hallucinations by up to 10 points on HaluEval. It also delivers 2.5$\times$ faster inference than RAG, with inference time that stays constant regardless of datastore size. The work presents a practical way to bridge parametric and non-parametric memory, offering better knowledge utilization with scalable, efficient deployment for large language models.

Abstract

Modern approaches to enhancing Large Language Models' factual accuracy and knowledge utilization face a fundamental trade-off: non-parametric retrieval-augmented generation (RAG) provides flexible access to external knowledge but suffers from high inference latency and shallow integration, while parametric fine-tuning methods like LoRA risk catastrophic forgetting and degraded general capabilities. In this work, we propose MLP Memory, a lightweight parametric module that learns to internalize retrieval patterns without explicit document access. By pretraining an MLP to imitate a $k$NN retriever's behavior on the entire pretraining dataset, we create a differentiable memory component that captures the benefits of retrieval-based knowledge access in a fully parametric form. Our architecture integrates this pretrained MLP Memory with Transformer decoders through simple probability interpolation, yielding 17.5% and 24.1% scaling gains on WikiText-103 and Web datasets, respectively. It further achieves 12.3% relative improvement on five question-answering benchmarks and 5.2 points absolute gain across nine general NLP tasks, while reducing hallucinations by up to 10 points on HaluEval. Moreover, MLP Memory delivers 2.5$\times$ faster inference than RAG with superior accuracy. Our findings show that learning retrieval patterns parametrically bridges the gap between efficient inference and effective knowledge access, offering a practical alternative to both RAG and fine-tuning approaches.
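
The "simple probability interpolation" mentioned in the abstract can be illustrated with a minimal sketch. It assumes the usual convex combination of two next-token distributions; the function names, the weight `lam`, and the toy logits below are hypothetical and are not taken from the paper.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def interpolate(p_lm, p_mem, lam=0.25):
    """Convex combination of the LLM and memory distributions.

    `lam` is a hypothetical interpolation weight (assumption): 0 returns
    the base LLM distribution, 1 returns the memory distribution.
    """
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return lam * p_mem + (1.0 - lam) * p_lm

# Toy next-token logits over a 4-token vocabulary (made-up numbers).
lm_logits = np.array([2.0, 1.0, 0.1, -1.0])
mem_logits = np.array([0.5, 3.0, 0.0, -2.0])

p_lm = softmax(lm_logits)
p_mem = softmax(mem_logits)
p_final = interpolate(p_lm, p_mem, lam=0.25)

# Greedy choice from the interpolated distribution.
next_token = int(np.argmax(p_final))
```

Because both inputs are probability distributions and the weights sum to one, the result is itself a valid distribution, so no renormalization is needed.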

Paper Structure

This paper contains 28 sections, 10 equations, 12 figures, 9 tables.

Figures (12)

  • Figure 1: Performance and efficiency comparison. Left: accuracy across three QA benchmarks. MLP Memory consistently outperforms the base model, surpassing both parametric methods (CPT, LoRA) and non-parametric retrieval (RAG). Right: inference efficiency, measured by time to first token (TTFT, $\downarrow$ lower is better) and tokens per second (TPS, $\uparrow$ higher is better). RAG results are shown for top-5 retrieval. $k$NN-LM is accelerated via dimension reduction (4096$\rightarrow$256), and both RAG and $k$NN-LM use the Wikipedia-2021 retrieval corpus. MLP Memory uses 1B parameters.
  • Figure 2: Approaches to enhance factual accuracy and knowledge utilization. Top left: Non-parametric RAG provides flexible knowledge access but suffers from high latency. Top right: Parametric fine-tuning risks catastrophic forgetting. Bottom: MLP Memory learns retrieval patterns during training (left) and enables efficient inference without explicit retrieval (right).
  • Figure 3: Comparison of model outputs on a factual question. Despite retrieving relevant documents with correct information (highlighted in green), RAG is misled by contextual distractors and produces an incorrect answer. MLP Memory generates the correct answer without explicit retrieval.
  • Figure 4: Overview of MLP Memory architecture. (a) Inference: MLP Memory processes context representations from a specific LLM layer, generating token probabilities that are interpolated with LLM outputs for final predictions. (b) Training: MLP Memory learns to imitate retriever behavior using LLM representations as input and distributions generated by $k$NN retrievers as targets, optimized through a hybrid objective.
  • Figure 5: Power-law scaling behavior with model size $N$ and training compute $C$. (a) Scaling results compare the continued training of GPT2 (GPT2-ConTrain) with our overall model architecture (GPT2+MLP Mem) under fixed compute. Our fitted curve shows a 17.5% exponent improvement on WikiText-103. (b) On the larger Web dataset, our architecture exhibits stronger scaling gains from increased data size, with an exponent improvement of 24.1%. (c) At the GPT2-xl scale, our architecture continues to benefit from additional training on the Web dataset without overfitting.
  • ...and 7 more figures
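
The Figure 4 caption describes training targets as "distributions generated by $k$NN retrievers". As a rough illustration of what such a target can look like, the sketch below follows the standard $k$NN-LM recipe: a softmax over negative distances to the nearest stored keys, aggregated by the token stored with each key. This is an assumption; this page does not specify the retriever or the hybrid objective, and every name, shape, and parameter here (`k`, `temperature`, the random datastore) is hypothetical.

```python
import numpy as np

def knn_target_distribution(query, keys, values, vocab_size, k=4, temperature=1.0):
    """Build a kNN-LM style next-token distribution (assumed formulation).

    query:  (d,) context representation
    keys:   (n, d) stored context representations
    values: (n,) token id that followed each stored context
    """
    # Squared L2 distance from the query to every stored key.
    dists = np.sum((keys - query) ** 2, axis=1)
    k = min(k, len(keys))
    nearest = np.argsort(dists)[:k]

    # Softmax over negative distances of the k nearest neighbours.
    scores = -dists[nearest] / temperature
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()

    # Sum the weights of neighbours that share the same next token.
    probs = np.zeros(vocab_size)
    np.add.at(probs, values[nearest], weights)
    return probs

# Toy datastore with random keys and token ids (made-up data).
rng = np.random.default_rng(0)
keys = rng.normal(size=(32, 8))
values = rng.integers(0, 10, size=32)

# A query that lies very close to stored key 3.
query = keys[3] + 0.01 * rng.normal(size=8)
p_knn = knn_target_distribution(query, keys, values, vocab_size=10)
```

In a setup like the one the caption describes, a distribution of this kind would serve as the target that the MLP learns to reproduce from the LLM's hidden representation, so that no datastore lookup is needed at inference time.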