
Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning

Yiju Guo, Wenkai Yang, Zexu Sun, Ning Ding, Zhiyuan Liu, Yankai Lin

TL;DR

This paper addresses the problem that large language models often misallocate attention due to distracting patterns during long-context reasoning. It introduces Learning to Focus (LeaF), a two-stage, causal-attention distillation framework that first identifies confounding tokens via gradient comparisons between a high-capacity teacher and a student, then prunes these tokens during distillation to align the student with the teacher on true causal context. LeaF uses span pruning to generate counterfactuals and employs a hybrid distillation objective that combines standard and counterfactual guidance, with response-splitting strategies to capture distractions at both instruction and generation levels. Across math, code, and multi-hop QA benchmarks, LeaF yields consistent accuracy gains over standard KD and improves interpretability by reducing attention to confounders, with only modest training overhead.
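The hybrid objective described above can be illustrated with a minimal NumPy sketch. It makes three assumptions. The distillation loss is token-level KL divergence from teacher to student. The counterfactual term feeds both models the pruned prompt. The mixing weight `lam` is hypothetical. The function names are illustrative, not the authors' code. Logits are assumed to be already sliced to the response positions, so the original and pruned inputs yield arrays of the same shape.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def token_kl(teacher_logits: np.ndarray, student_logits: np.ndarray) -> float:
    """Mean over response positions of KL(teacher || student)."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    p = np.exp(log_p)
    return float((p * (log_p - log_q)).sum(axis=-1).mean())


def hybrid_distillation_loss(
    teacher_orig: np.ndarray,
    student_orig: np.ndarray,
    teacher_cf: np.ndarray,
    student_cf: np.ndarray,
    lam: float = 0.5,  # hypothetical weight on the counterfactual term
) -> float:
    """Blend standard KD on the original prompt with KD on the pruned (counterfactual) prompt."""
    standard = token_kl(teacher_orig, student_orig)
    counterfactual = token_kl(teacher_cf, student_cf)
    return (1.0 - lam) * standard + lam * counterfactual


# Toy usage: 4 response positions over a 10-token vocabulary.
rng = np.random.default_rng(0)
t_orig, s_orig = rng.normal(size=(4, 10)), rng.normal(size=(4, 10))
t_cf, s_cf = rng.normal(size=(4, 10)), rng.normal(size=(4, 10))
print(hybrid_distillation_loss(t_orig, s_orig, t_cf, s_cf, lam=0.5))
```

Setting `lam` to 0 recovers ordinary knowledge distillation in this sketch, which makes a comparison against the standard KD baseline straightforward.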

Abstract

Large language models (LLMs) have demonstrated significant improvements in contextual understanding. However, their ability to attend to truly critical information during long-context reasoning and generation still lags behind. Specifically, our preliminary experiments reveal that certain distracting patterns can misdirect the model's attention during inference, and removing these patterns substantially improves reasoning accuracy and generation quality. We attribute this phenomenon to spurious correlations in the training data, which obstruct the model's capacity to infer authentic causal instruction-response relationships. Such misdirected attention may induce redundant reasoning, potentially causing significant inference overhead and, more critically, erroneous or suboptimal responses. To mitigate this, we introduce a two-stage framework called Learning to Focus (LeaF), which leverages intervention-based inference to disentangle confounding factors. In the first stage, LeaF uses gradient-based comparisons with an advanced teacher to automatically identify confounding tokens based on causal relationships in the training corpus. In the second stage, it prunes these tokens during distillation to enact the intervention, aligning the student's attention with the teacher's focus distribution on truly critical context tokens. Experimental results demonstrate that LeaF not only achieves absolute improvements across various mathematical reasoning, code generation, and multi-hop question answering benchmarks but also effectively suppresses attention to confounding tokens during inference, yielding a more interpretable and reliable reasoning model.
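The first stage can be pictured with the following NumPy sketch. It assumes per-token gradient magnitudes have already been computed for the teacher and the student, for example the norm of the loss gradient with respect to each input token's embedding. The min-max normalization, the student-minus-teacher gap, and the `threshold` value are illustrative assumptions, not the paper's exact criterion.

```python
import numpy as np


def minmax_normalize(x: np.ndarray) -> np.ndarray:
    """Rescale scores to [0, 1]; a constant vector maps to all zeros."""
    lo, hi = float(x.min()), float(x.max())
    if hi - lo < 1e-12:
        return np.zeros_like(x, dtype=float)
    return (x - lo) / (hi - lo)


def detect_confounding_tokens(
    teacher_grad: np.ndarray,
    student_grad: np.ndarray,
    threshold: float = 0.3,  # hypothetical cut-off on the normalized gap
) -> np.ndarray:
    """Flag prompt positions the student is far more sensitive to than the teacher."""
    teacher_score = minmax_normalize(np.abs(teacher_grad))
    student_score = minmax_normalize(np.abs(student_grad))
    gap = student_score - teacher_score
    return np.flatnonzero(gap > threshold)


# Toy per-token gradient magnitudes for a 5-token prompt.
teacher = np.array([0.9, 0.1, 0.8, 0.05, 0.7])
student = np.array([0.2, 0.9, 0.3, 0.8, 0.25])
print(detect_confounding_tokens(teacher, student))  # [1 3]
```

Normalizing each model's scores separately keeps the comparison about relative reliance on each token rather than about the absolute gradient scale, which can differ widely between a large teacher and a small student.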

Paper Structure

This paper contains 48 sections, 9 equations, 13 figures, 10 tables.

Figures (13)

  • Figure 1: Accuracy improvements achieved by removing confounding tokens from the inputs of small models on the math and code training corpora. Performance increases significantly: by over 20% on the math corpus and by more than 10% on the code corpus. (For further details, see the appendix on the pre-examination experiments.)
  • Figure 2: Comparison of reasoning before and after pruning distracting patterns. Blue-shaded regions indicate pruned confounding tokens. Pink highlights mark areas that require focus, while blue highlights show where excessive attention caused errors.
  • Figure 3: Causal graph of the reasoning process. $X$ represents the input prompt, and $Y$ denotes the model's output. A subset of tokens in $X$, identified as confounding tokens ($A$), introduces spurious correlations that disrupt the reasoning process. Our method detects and masks $A$, effectively eliminating the spurious edge from $A$ to $Y$ and restoring the true causal dependency.
  • Figure 4: Method overview. The training pipeline comprises two key stages: (1) Confounding Token Detection: gradient-based comparisons between an advanced teacher model and the student model identify confounding tokens in the training samples, and counterfactual samples are constructed by pruning these tokens; and (2) Causal Attention Distillation: the identified confounders are pruned during training to align the student's attention with the teacher's and capture causal relationships. This targeted intervention steers the model toward actual causal dependencies, enhancing both robustness and interpretability.
  • Figure 5: Illustration of Collective Pruning and Span Pruning.
  • ...and 8 more figures
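Figure 5 contrasts collective pruning and span pruning, but this summary does not define them. The Python sketch below assumes that collective pruning removes every flagged token in one pass. It also assumes that span pruning groups flagged positions into contiguous runs and removes one run at a time, yielding one counterfactual per span. All function names are hypothetical.

```python
from typing import List, Sequence


def group_into_spans(indices: Sequence[int]) -> List[List[int]]:
    """Group token indices into maximal runs of consecutive positions."""
    spans: List[List[int]] = []
    for i in sorted(indices):
        if spans and i == spans[-1][-1] + 1:
            spans[-1].append(i)  # extends the current contiguous span
        else:
            spans.append([i])  # starts a new span
    return spans


def collective_prune(tokens: Sequence[str], confounders: Sequence[int]) -> List[str]:
    """Drop every flagged position at once, giving a single counterfactual."""
    drop = set(confounders)
    return [tok for i, tok in enumerate(tokens) if i not in drop]


def span_prune(tokens: Sequence[str], confounders: Sequence[int]) -> List[List[str]]:
    """Drop one contiguous span at a time, giving one counterfactual per span."""
    return [collective_prune(tokens, span) for span in group_into_spans(confounders)]


tokens = ["A", "B", "C", "D", "E", "F"]
flagged = [1, 2, 4]
print(collective_prune(tokens, flagged))  # ['A', 'D', 'F']
print(span_prune(tokens, flagged))  # [['A', 'D', 'E', 'F'], ['A', 'B', 'C', 'D', 'F']]
```

In terms of the causal graph in Figure 3, each pruned prompt is an intervention that removes part of the confounding set A. Span pruning isolates the effect of each span separately, while collective pruning removes all of A at once.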