Filtering with Self-Attention and Storing with MLP: One-Layer Transformers Can Provably Acquire and Extract Knowledge
Ruichen Xu, Kexin Chen
TL;DR
This work develops a theoretical framework for knowledge acquisition and extraction in transformers by analyzing a simplified one-layer architecture that combines self-attention and an MLP. It proves that pre-training with next-token prediction converges to near-optimal loss. It also proves that, given sufficiently large fine-tuning data and sufficient relation multiplicity, the fine-tuned model achieves an out-of-distribution generalization bound of the form $\exp\left(-\dfrac{\tilde{N}_f K^2 \log|\mathcal{R}|}{2|\mathcal{R}|^2}\right)$, whereas violating these conditions leads to hallucinations. The paper also explains why low-rank fine-tuning is effective: the fine-tuning gradient concentrates on a dominant rank-1 component aligned with the Q&A format embedding. These insights are validated on synthetic data and on the real-world PopQA benchmark with GPT-2 and Llama-3.2-1B. Overall, the results show, in a tractable and provable setting, how the interaction between self-attention and the MLP underpins knowledge storage and retrieval, offering guidance for practical fine-tuning strategies and for future work on deeper architectures.
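To make the scaling of the TL;DR bound concrete, this short Python sketch evaluates the expression exactly as written and inverts it for the data size. The parameter readings are our assumption from the surrounding wording: $\tilde{N}_f$ as the fine-tuning sample size, $K$ as the relation multiplicity, and $|\mathcal{R}|$ as the number of relations. The function names `ood_bound` and `samples_for_target` are hypothetical, and the sketch ignores any constants or side conditions the full theorem may attach.

```python
import math


def ood_bound(n_f: float, k: float, num_relations: int) -> float:
    """Evaluate exp(-N_f * K^2 * log|R| / (2 * |R|^2)) as written.

    Hypothetical names (assumed readings of the symbols):
      n_f           ~ tilde{N}_f, fine-tuning sample size
      k             ~ K, relation multiplicity
      num_relations ~ |R|, number of relations
    """
    if num_relations < 1:
        raise ValueError("num_relations must be >= 1")
    exponent = -n_f * k**2 * math.log(num_relations) / (2 * num_relations**2)
    return math.exp(exponent)


def samples_for_target(eps: float, k: float, num_relations: int) -> float:
    """Solve ood_bound(n_f, k, |R|) == eps for n_f.

    Rearranging eps = exp(-n_f K^2 log|R| / (2|R|^2)) gives
    n_f = 2 |R|^2 log(1/eps) / (K^2 log|R|); requires |R| >= 2 and 0 < eps < 1.
    """
    if num_relations < 2 or not 0 < eps < 1:
        raise ValueError("need num_relations >= 2 and 0 < eps < 1")
    return 2 * num_relations**2 * math.log(1 / eps) / (k**2 * math.log(num_relations))


if __name__ == "__main__":
    print(ood_bound(100, 2, 10))             # exp(-2 ln 10), i.e. approximately 0.01
    print(ood_bound(100, 2, 20))             # 20 ** -0.5, approximately 0.224
    print(samples_for_target(0.01, 2, 10))   # approximately 100
```

With $|\mathcal{R}| = 10$, $K = 2$ and $\tilde{N}_f = 100$ the exponent is $-2\ln 10$, so the expression equals $0.01$. Doubling $|\mathcal{R}|$ to 20 at the same data budget raises it to $1/\sqrt{20} \approx 0.224$, because the exponent scales as $\log|\mathcal{R}|/|\mathcal{R}|^2$; larger $\tilde{N}_f$ or $K$ tightens it again.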
Abstract
Modern large language models (LLMs) demonstrate exceptional performance on knowledge-intensive tasks, yet the theoretical mechanisms underlying knowledge acquisition (storage and memorization) during pre-training and knowledge extraction (retrieval and recall) during inference after fine-tuning remain poorly understood. Although prior theoretical studies have explored these processes through analyses of training dynamics, they overlook critical components essential for a comprehensive theory: (1) the multi-layer perceptron (MLP), empirically identified as the primary module for knowledge storage; (2) out-of-distribution (OOD) adaptivity, which enables LLMs to generalize to unseen scenarios after pre-training; and (3) next-token prediction, the standard autoregressive objective that encodes knowledge as conditional probabilities. In this work, we introduce, to the best of our knowledge, the first theoretical framework that addresses these limitations by examining the training dynamics of one-layer transformers. Under regularity assumptions, we establish that: (i) transformers attain near-optimal training loss during pre-training, demonstrating effective knowledge acquisition; (ii) given a sufficiently large fine-tuning dataset and appropriate data multiplicity conditions, transformers achieve low generalization error on factual knowledge acquired during pre-training but not revisited during fine-tuning, indicating robust knowledge extraction; and (iii) violating these conditions leads to elevated generalization error, manifesting as hallucinations. Our analysis covers both full fine-tuning and low-rank fine-tuning, yielding insights into the efficacy of practical low-rank adaptation methods. We validate our theoretical findings through experiments on synthetic datasets and the real-world PopQA benchmark, using GPT-2 and Llama-3.2-1B models.
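The low-rank fine-tuning point can be illustrated with a toy NumPy experiment. This is not the paper's derivation; it is a sketch under our own assumptions: a single linear layer $y = Wx$, per-example inputs that share one fixed "format" direction `e` (a hypothetical stand-in for the Q&A format embedding) plus small example-specific noise, and per-example output errors that share a common mean `mu`. Under these assumptions the summed gradient $\sum_i \delta_i x_i^\top$ is dominated by one singular direction whose right singular vector aligns with `e`.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 64, 200  # toy hidden size and number of fine-tuning examples

# Hypothetical shared "format" direction (stand-in for the Q&A format embedding).
e = rng.normal(size=d)
e /= np.linalg.norm(e)

# Assumption: per-example output errors share a common mean direction mu.
mu = rng.normal(size=d)
mu /= np.linalg.norm(mu)

# Per-example inputs: shared format component plus small example-specific noise.
X = e + 0.1 * rng.normal(size=(n, d))       # shape (n, d), row i is x_i
# Per-example output errors delta_i = dL_i/dy: shared mean plus small noise.
Delta = mu + 0.1 * rng.normal(size=(n, d))  # shape (n, d), row i is delta_i

# For y = W x, the gradient of sum_i L_i w.r.t. W is sum_i delta_i x_i^T.
G = Delta.T @ X                             # shape (d, d)

U, S, Vt = np.linalg.svd(G)
energy_top1 = S[0] ** 2 / np.sum(S ** 2)    # fraction of squared Frobenius norm in rank 1
alignment = abs(Vt[0] @ e)                  # |cos| between top right singular vector and e

print(f"top-1 energy fraction: {energy_top1:.3f}")
print(f"|cos(v1, e)|: {alignment:.3f}")
```

The shared components are what make the per-example outer products add up coherently across examples, while the example-specific noise terms largely average out; a rank-1 update along the dominant direction would then capture most of this toy gradient.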
