MLP Memory: A Retriever-Pretrained Memory for Large Language Models
Rubin Wei, Jiaqi Cao, Jiarui Wang, Jushi Kai, Qipeng Guo, Bowen Zhou, Zhouhan Lin
TL;DR
MLP Memory proposes a lightweight parametric memory that internalizes retrieval patterns: an all-MLP network is pretrained to imitate a $k$NN retriever over the entire pretraining corpus, and its output distribution is interpolated with the base LLM's during inference. This approach aims to combine the factual grounding of retrieval with the efficiency of parametric models, avoiding the latency and shallow integration of RAG while mitigating the catastrophic forgetting seen with continued pretraining (CPT) and LoRA fine-tuning. Empirically, MLP Memory shows strong scaling behavior, substantial gains on QA benchmarks, broad improvements on general NLP tasks, and notable reductions in hallucination, while delivering 2.5$\times$ faster inference than RAG with an inference cost that does not grow with datastore size. The work offers a practical way to bridge parametric and non-parametric memory, improving knowledge utilization in large language models while remaining scalable and efficient to deploy.
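The inference-time combination described above can be pictured as a token-level mixture of two next-token distributions, in the spirit of kNN-LM. The sketch below is a minimal illustration under stated assumptions: it assumes a single fixed interpolation weight `lam` and toy probability vectors over a shared vocabulary. The function names `softmax` and `interpolate`, the value of `lam`, and the logits are hypothetical and not taken from the paper.

```python
import math
from typing import Sequence


def softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def interpolate(p_llm: Sequence[float], p_mem: Sequence[float], lam: float) -> list[float]:
    """Hypothetical mixture lam * p_mem + (1 - lam) * p_llm, token by token.

    Both inputs must be distributions over the same vocabulary; `lam` is an
    assumed fixed weight in [0, 1], not a value reported by the paper.
    """
    if len(p_llm) != len(p_mem):
        raise ValueError("distributions must share a vocabulary")
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return [lam * m + (1.0 - lam) * l for l, m in zip(p_llm, p_mem)]


# Toy example over a 4-token vocabulary (all numbers are illustrative).
p_llm = softmax([2.0, 1.0, 0.5, 0.1])  # stand-in for the base decoder's distribution
p_mem = softmax([0.1, 3.0, 0.2, 0.1])  # stand-in for the memory module's distribution
p = interpolate(p_llm, p_mem, lam=0.25)
print([round(x, 3) for x in p], round(sum(p), 6))
```

Because the mixture is convex, the result is still a valid probability distribution, and the memory's contribution can be tuned by a single scalar without touching the decoder's weights.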
Abstract
Modern approaches to enhancing Large Language Models' factual accuracy and knowledge utilization face a fundamental trade-off: non-parametric retrieval-augmented generation (RAG) provides flexible access to external knowledge but suffers from high inference latency and shallow integration, while parametric fine-tuning methods like LoRA risk catastrophic forgetting and degraded general capabilities. In this work, we propose MLP Memory, a lightweight parametric module that learns to internalize retrieval patterns without explicit document access. By pretraining an MLP to imitate a $k$NN retriever's behavior on the entire pretraining dataset, we create a differentiable memory component that captures the benefits of retrieval-based knowledge access in a fully parametric form. Our architecture integrates this pretrained MLP Memory with Transformer decoders through simple probability interpolation, yielding 17.5\% and 24.1\% scaling gains on WikiText-103 and Web datasets, respectively. It further achieves 12.3\% relative improvement on five question-answering benchmarks and 5.2 points absolute gain across nine general NLP tasks, while reducing hallucinations by up to 10 points on HaluEval. Moreover, MLP Memory delivers 2.5$\times$ faster inference than RAG with superior accuracy. Our findings show that learning retrieval patterns parametrically bridges the gap between efficient inference and effective knowledge access, offering a practical alternative to both RAG and fine-tuning approaches.
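The training signal for "imitating a $k$NN retriever" can be made concrete with a kNN-LM-style target: retrieve the $k$ stored contexts nearest to the current hidden state, turn their negative distances into weights, and pool those weights onto the tokens that followed each stored context. The sketch below assumes a brute-force search over a tiny in-memory datastore, squared L2 distance, and a KL divergence between the $k$NN target and the model's distribution as the imitation loss; the paper's exact distance, temperature, and loss may differ. The names `knn_distribution`, `kl_divergence`, `Entry`, and all toy values are hypothetical.

```python
import math
from typing import Sequence

# Hypothetical datastore entry: (context key vector, id of the token that followed it).
Entry = tuple[list[float], int]


def knn_distribution(query: Sequence[float], datastore: Sequence[Entry],
                     k: int, vocab_size: int, temperature: float = 1.0) -> list[float]:
    """kNN-LM-style target: softmax over negative squared L2 distances of the
    k nearest keys, with each neighbor's mass added to its next-token id."""
    neighbors = sorted(
        (sum((q - x) ** 2 for q, x in zip(query, key)), tok)
        for key, tok in datastore
    )[:k]
    d_min = neighbors[0][0]
    # Shift by the smallest distance so exp() cannot underflow to all zeros.
    weights = [math.exp(-(d - d_min) / temperature) for d, _ in neighbors]
    total = sum(weights)
    probs = [0.0] * vocab_size
    for w, (_, tok) in zip(weights, neighbors):
        probs[tok] += w / total
    return probs


def kl_divergence(p_target: Sequence[float], q_model: Sequence[float],
                  eps: float = 1e-12) -> float:
    """KL(p_target || q_model); tokens with zero target mass contribute nothing."""
    return sum(p * math.log(p / max(q, eps))
               for p, q in zip(p_target, q_model) if p > 0)


# Toy 2-D datastore over a 4-token vocabulary (illustrative values only).
datastore: list[Entry] = [([0.0, 0.0], 1), ([0.1, 0.0], 1), ([1.0, 1.0], 2), ([5.0, 5.0], 3)]
target = knn_distribution([0.05, 0.0], datastore, k=3, vocab_size=4)
mlp_output = [0.25, 0.25, 0.25, 0.25]  # stand-in for the MLP's predicted distribution
loss = kl_divergence(target, mlp_output)
print([round(x, 3) for x in target], round(loss, 4))
```

Minimizing such a loss over the whole corpus pushes the MLP to reproduce the retriever's output distribution directly from the query, which is why no datastore lookup is needed at inference time.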
