Disentangling Memory and Reasoning Ability in Large Language Models
Mingyu Jin, Weidi Luo, Sitao Cheng, Xinyi Wang, Wenyue Hua, Ruixiang Tang, William Yang Wang, Yongfeng Zhang
TL;DR
This work tackles the opacity of large language model inference by explicitly separating memory recall from reasoning using two trainable special tokens, ⟨memory⟩ and ⟨reason⟩. It introduces a two-stage framework: (1) generating training data in which memory and reasoning steps are decoupled, and (2) training the LLM to condition on these tokens so that knowledge retrieval is disentangled from inference. Across StrategyQA, CommonsenseQA, and TruthfulQA, the approach yields competitive performance: it surpasses GPT-4o on TruthfulQA and narrows the gap to GPT-4o on the other benchmarks, while improving interpretability through labeled step-by-step reasoning. Ablation and analysis indicate that the memory/reason separation mainly affects reasoning accuracy, and attention patterns confirm that the two tokens play an important role in guiding inference.
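To make the data-generation stage more concrete, the sketch below serializes one question into a training target in which every step is labeled as either recalled knowledge or reasoning. It is a minimal illustration under stated assumptions, not the authors' released code: the ASCII token spellings "<memory>" and "<reason>", the "Question:"/"Answer:" layout, and the function name build_training_example are all hypothetical.

```python
from typing import List, Tuple

# Hypothetical token spellings; the paper's exact strings may differ.
TOKENS = {"memory": "<memory>", "reason": "<reason>"}


def build_training_example(question: str,
                           steps: List[Tuple[str, str]],
                           answer: str) -> str:
    """Serialize decoupled steps into a single training target string.

    Each step is a (kind, text) pair: kind is "memory" for a recalled
    fact or "reason" for an inference drawn from earlier steps.
    """
    lines = [f"Question: {question}"]
    for kind, text in steps:
        if kind not in TOKENS:
            raise ValueError(f"unknown step kind: {kind!r}")
        # Prefix every step with the token that labels its role.
        lines.append(f"{TOKENS[kind]} {text}")
    lines.append(f"Answer: {answer}")
    return "\n".join(lines)


example = build_training_example(
    "Is the boiling point of water at sea level above 90 degrees Celsius?",
    [("memory", "Water boils at 100 degrees Celsius at sea level."),
     ("reason", "100 is greater than 90, so the boiling point is above 90.")],
    "Yes",
)
print(example)
```

Keeping each step on its own line with a leading label makes the role of every step explicit in the training text, which is the property the framework relies on; how the tokens are actually embedded and trained is left to the paper and its repository.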
Abstract
Large Language Models (LLMs) have demonstrated strong performance on complex tasks that require both extensive knowledge and reasoning ability. However, the existing LLM inference pipeline operates as an opaque process with no explicit separation between knowledge retrieval and reasoning steps, which makes the model's decision-making process unclear and disorganized. This ambiguity can lead to issues such as hallucination and knowledge forgetting, which significantly undermine the reliability of LLMs in high-stakes domains. In this paper, we propose a new inference paradigm that decomposes the complex inference process into two distinct actions: (1) memory recall, which retrieves relevant knowledge, and (2) reasoning, which performs logical steps based on the recalled knowledge. To facilitate this decomposition, we introduce two special tokens, ⟨memory⟩ and ⟨reason⟩, which guide the model to distinguish steps that require knowledge retrieval from those that involve reasoning. Our experimental results show that this decomposition not only improves model performance but also enhances the interpretability of the inference process, enabling users to identify sources of error and refine model responses effectively. The code is available at https://github.com/MingyuJ666/Disentangling-Memory-and-Reasoning.
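As an illustration of how explicitly labeled steps can help a user locate the source of an error, the following minimal sketch splits a tagged model response into recalled facts and reasoning steps so that each group can be checked separately. It assumes, hypothetically, that steps appear one per line prefixed by the ASCII strings "<memory>" or "<reason>"; the output format of the released code may differ, and split_steps and steps_of_kind are illustrative names only.

```python
import re
from typing import List, Tuple

# Hypothetical token spellings; the released code may use different strings.
STEP_PATTERN = re.compile(r"<(memory|reason)>\s*(.*)")


def split_steps(response: str) -> List[Tuple[str, str]]:
    """Return (kind, text) pairs for each tagged line of a model response."""
    steps = []
    for line in response.splitlines():
        match = STEP_PATTERN.match(line.strip())
        if match:  # untagged lines, such as the final answer, are skipped
            steps.append((match.group(1), match.group(2).strip()))
    return steps


def steps_of_kind(steps: List[Tuple[str, str]], kind: str) -> List[str]:
    """Keep only the step texts whose label matches kind."""
    return [text for k, text in steps if k == kind]


response = """<memory> Water boils at 100 degrees Celsius at sea level.
<reason> 100 is greater than 90, so the boiling point is above 90.
Answer: Yes"""

steps = split_steps(response)
# A reviewer can verify recalled facts independently of the inference steps.
for fact in steps_of_kind(steps, "memory"):
    print("recalled:", fact)
for inference in steps_of_kind(steps, "reason"):
    print("inferred:", inference)
```

If the final answer is wrong, a wrong "recalled" line points to a knowledge error, while correct facts followed by a faulty "inferred" line point to a reasoning error.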
