Back Attention: Understanding and Enhancing Multi-Hop Reasoning in Large Language Models
Zeping Yu, Yonatan Belinkov, Sophia Ananiadou
TL;DR
This work studies latent multi-hop reasoning in large language models. It introduces logit flow, an interpretability method that traces logits across layers and positions, and uses it to reveal four stages in single-hop prediction and the causes of failures in two-hop reasoning. Building on these insights, the authors propose back attention, a lightweight mechanism that lets lower layers incorporate higher-layer features by computing queries at a lower layer and keys/values from higher-layer hidden states. With back attention, a 1-layer transformer matches the performance of a 2-layer model. Experiments on four LLMs and five reasoning datasets show consistent accuracy gains, and a focused arithmetic task shows notable improvements with minimal parameter overhead. The work advances the understanding of latent reasoning in LLMs and offers a practical technique for enhancing multi-hop reasoning that applies both to fine-tuning pretrained models and to interpretability analysis.
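The back attention idea (queries from a lower layer, keys and values from a higher layer) can be illustrated with a minimal single-head NumPy sketch. It is not the authors' implementation. It assumes a causal mask, a residual addition of the attention output to the lower-layer states, and hypothetical weight matrices `W_q`, `W_k`, `W_v`, `W_o`. The paper's exact formulation may differ in several ways: multi-head attention, normalization, which layers are paired, and how the updated states are propagated through the remaining layers.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax; -inf entries become exactly 0.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def back_attention(h_low, h_high, W_q, W_k, W_v, W_o):
    """Single-head back attention (illustrative sketch, not the paper's code).

    h_low:  (T, d) hidden states at a lower layer; these supply the queries.
    h_high: (T, d) hidden states at a higher layer; these supply keys/values.
    W_q, W_k, W_v: (d, d_k) hypothetical projection matrices.
    W_o: (d_k, d) hypothetical output projection.
    Returns the lower-layer states with the attention output added (residual).
    """
    T = h_low.shape[0]
    q = h_low @ W_q                               # (T, d_k) queries from the lower layer
    k = h_high @ W_k                              # (T, d_k) keys from the higher layer
    v = h_high @ W_v                              # (T, d_k) values from the higher layer
    scores = q @ k.T / np.sqrt(q.shape[-1])       # (T, T) scaled dot products
    # Assumed causal mask: position i may only read higher-layer states at positions <= i.
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    attn = softmax(scores, axis=-1)
    return h_low + attn @ v @ W_o                 # (T, d) enriched lower-layer states


# Toy usage with random weights.
rng = np.random.default_rng(0)
T, d, d_k = 5, 8, 4
h_low = rng.normal(size=(T, d))
h_high = rng.normal(size=(T, d))
W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))
W_o = rng.normal(size=(d_k, d))
h_new = back_attention(h_low, h_high, W_q, W_k, W_v, W_o)
print(h_new.shape)  # (5, 8)
```

Setting `W_o` to zero returns the original lower-layer states, so in this sketch the mechanism only adds a learned correction on top of the unmodified forward pass. Recomputing the layers above the query layer with the enriched states is left out here. That step would be needed for the higher-layer features to influence the final prediction.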
Abstract
We investigate how large language models perform latent multi-hop reasoning in prompts like "Wolfgang Amadeus Mozart's mother's spouse is". To analyze this process, we introduce logit flow, an interpretability method that traces how logits propagate across layers and positions toward the final prediction. Using logit flow, we identify four distinct stages in single-hop knowledge prediction: (A) entity subject enrichment, (B) entity attribute extraction, (C) relation subject enrichment, and (D) relation attribute extraction. Extending this analysis to multi-hop reasoning, we find that failures often stem from the relation attribute extraction stage, where conflicting logits reduce prediction accuracy. To address this, we propose back attention, a novel mechanism that enables lower layers to leverage higher-layer hidden states from different positions during attention computation. With back attention, a 1-layer transformer achieves the performance of a 2-layer transformer. Applied to four LLMs, back attention improves accuracy on five reasoning datasets, demonstrating its effectiveness in enhancing latent multi-hop reasoning ability.
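The abstract describes logit flow only at a high level. A simplified version of the underlying idea, in the style of direct logit attribution, projects each layer's attention output (split by the source position it attended to) and FFN output at the final position onto the unembedding vector of a target token. The sketch below is not the authors' implementation. It assumes single-head attention, ignores the final layer normalization, and uses hypothetical array names.

```python
import numpy as np


def logit_flow(attn, values, W_o, ffn_out, W_U, token_id):
    """Direct-logit-attribution style sketch of tracing logits over layers and positions.

    attn:    (L, T, T) single-head attention weights per layer (assumed).
    values:  (L, T, d_k) value vectors per layer and position.
    W_o:     (L, d_k, d) output projection per layer.
    ffn_out: (L, T, d) FFN output per layer and position.
    W_U:     (d, V) unembedding matrix; final layer norm is ignored here.
    Returns:
      attn_flow: (L, T) logit of token_id written into the last position by each
                 layer's attention, split by the source position it attended to.
      ffn_flow:  (L,) logit of token_id written by each layer's FFN at the last position.
    """
    u = W_U[:, token_id]                                  # (d,) unembedding direction
    per_src = attn[:, -1, :, None] * values               # (L, T, d_k) weighted values read by the last position
    attn_vecs = np.einsum("ltk,lkd->ltd", per_src, W_o)   # (L, T, d) residual-stream writes
    attn_flow = attn_vecs @ u                             # (L, T)
    ffn_flow = ffn_out[:, -1, :] @ u                      # (L,)
    return attn_flow, ffn_flow


# Toy usage with random activations.
rng = np.random.default_rng(0)
L, T, d, d_k, V = 3, 4, 8, 4, 10
attn = rng.random((L, T, T)) * np.tril(np.ones((T, T)))   # causal pattern
attn /= attn.sum(axis=-1, keepdims=True)                   # rows sum to 1
values = rng.normal(size=(L, T, d_k))
W_o = rng.normal(size=(L, d_k, d))
ffn = rng.normal(size=(L, T, d))
W_U = rng.normal(size=(d, V))
attn_flow, ffn_flow = logit_flow(attn, values, W_o, ffn, W_U, token_id=3)
layer, pos = np.unravel_index(np.argmax(attn_flow), attn_flow.shape)
print(f"largest attention contribution: layer {layer}, source position {pos}")
```

Because the decomposition is linear, the per-position scores in each row of `attn_flow` sum to that layer's total attention contribution to the target logit. This makes it straightforward to see which layers and positions push the prediction up or down, which is the kind of evidence the four-stage analysis relies on.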
