Did Models Sufficient Learn? Attribution-Guided Training via Subset-Selected Counterfactual Augmentation

Yannan Chen, Ruoyu Chen, Bin Zeng, Wei Wang, Shiming Liu, Qunli Zhang, Zheng Hu, Laiyuan Wang, Yaowei Wang, Xiaochun Cao

TL;DR

The paper attacks shortcut learning in visual models by integrating attribution-guided counterfactual reasoning into training. It introduces Subset-Selected Counterfactual Augmentation (SS-CA) and Counterfactual LIMA to identify minimal regions whose removal would flip predictions, then replaces those regions with plausible background content to generate hard, yet semantically consistent, augmented samples. Through a joint objective that combines standard supervision with attribution-guided augmentation and a submodular greedy search, SS-CA promotes learning more complete and causal decision rules. Experiments across ImageNet variants and multiple backbones demonstrate consistent gains in in-distribution accuracy and out-of-distribution robustness, including ImageNet-R and ImageNet-S, and improved performance under common corruptions, underscoring the practical impact of combining interpretability with training-time debiasing.

Abstract

In current visual model training, models often rely on only limited sufficient causes for their predictions, which makes them sensitive to distribution shifts or the absence of key features. Attribution methods can accurately identify a model's critical regions. However, masking these areas to create counterfactuals often causes the model to misclassify the target, while humans can still easily recognize it. This divergence highlights that the model's learned dependencies may not be sufficiently causal. To address this issue, we propose Subset-Selected Counterfactual Augmentation (SS-CA), which integrates counterfactual explanations directly into the training process for targeted intervention. Building on the subset-selection-based LIMA attribution method, we develop Counterfactual LIMA to identify minimal spatial region sets whose removal can selectively alter model predictions. Leveraging these attributions, we introduce a data augmentation strategy that replaces the identified regions with natural background, and we train the model jointly on both augmented and original samples to mitigate incomplete causal learning. Extensive experiments across multiple ImageNet variants show that SS-CA improves generalization on in-distribution (ID) test data and achieves superior performance on out-of-distribution (OOD) benchmarks such as ImageNet-R and ImageNet-S. Under perturbations including noise, models trained with SS-CA also exhibit enhanced generalization, demonstrating that our approach effectively uses interpretability insights to correct model deficiencies and improve both performance and robustness.
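
The pipeline described above (find a small set of regions whose removal flips the prediction, replace them with background, keep the original label) can be illustrated with a minimal sketch. This is not the authors' Counterfactual LIMA: the submodular objective is not given in this summary, so a plain greedy search over a regular patch grid stands in for it. The names `split_into_patches`, `greedy_counterfactual_regions`, and `toy_predict`, the grid partition, and the all-zero "background" image are hypothetical choices made only for illustration.

```python
import numpy as np

def split_into_patches(h, w, patch):
    """Regular grid of (row_slice, col_slice) regions (an assumed partition)."""
    return [(slice(r, min(r + patch, h)), slice(c, min(c + patch, w)))
            for r in range(0, h, patch) for c in range(0, w, patch)]

def greedy_counterfactual_regions(image, label, predict_proba, regions,
                                  background, max_regions=None):
    """Greedily replace regions with background until the prediction flips.

    At each step, the region whose replacement lowers the score of `label`
    the most is selected. Simplified stand-in for a submodular search.
    """
    current = image.copy()
    selected = []
    remaining = list(range(len(regions)))
    budget = len(regions) if max_regions is None else max_regions
    while remaining and len(selected) < budget:
        if int(np.argmax(predict_proba(current))) != label:
            break  # prediction already flipped: stop with a small set
        best_idx, best_score, best_img = None, np.inf, None
        for idx in remaining:
            rs, cs = regions[idx]
            trial = current.copy()
            trial[rs, cs] = background[rs, cs]  # replace region with background
            score = predict_proba(trial)[label]
            if score < best_score:
                best_idx, best_score, best_img = idx, score, trial
        selected.append(best_idx)
        remaining.remove(best_idx)
        current = best_img
    flipped = int(np.argmax(predict_proba(current))) != label
    return selected, current, flipped

# Toy two-class "model" (hypothetical): class 1 depends only on the
# top-left 2x2 quadrant, i.e. it has learned a single sufficient cause.
def toy_predict(img):
    s = float(img[:2, :2].mean())
    return np.array([1.0 - s, s])

image = np.zeros((4, 4))
image[:2, :2] = 1.0
background = np.zeros_like(image)  # assumed stand-in for natural background
regions = split_into_patches(4, 4, patch=2)

selected, augmented, flipped = greedy_counterfactual_regions(
    image, 1, toy_predict, regions, background)

# The augmented image keeps the ORIGINAL label and is trained on
# jointly with the unmodified sample.
training_pairs = [(image, 1), (augmented, 1)]
```

In an actual training loop, `toy_predict` would be replaced by the network being trained, and `training_pairs` would feed a standard supervised loss on both the original and the augmented sample.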

Paper Structure

This paper contains 13 sections, 10 equations, 6 figures, 5 tables, 1 algorithm.

Figures (6)

  • Figure 1: Conceptual motivation for identifying and mitigating shortcut learning. The top panel diagnoses the problem of a model learning limited sufficient causes, where it relies only on a single feature (Cause 1). The bottom panel illustrates our solution, where counterfactual augmentation refines the model's decision boundary for more robust recognition.
  • Figure 2: The overall framework of Subset-Selected Counterfactual Augmentation (SS-CA). It forms a closed training loop with three stages: (1) Conventional Training, where a factual image (e.g., “Robin”) is fed into the network to obtain an initial prediction; (2) Counterfactual Explanation, which identifies a minimal set of regions whose removal flips the prediction to a counterfactual class (e.g., “Heron”); and (3) Attribution-guided Augmentation, which uses the counterfactual mask to replace these regions with random background, yielding a hard augmented sample that retains its original ground-truth label (“Robin”) and is fed back into training.
  • Figure 3: Visualization of the SS-CA training loop on ImageNet-100. The image demonstrates our Counterfactual LIMA identifying spurious features for removal. The accompanying "Deletion Curve" plot confirms the resulting successful prediction flip.
  • Figure 4: Illustration of the SS-CA framework operating on TinyImageNet-200. The framework demonstrates a robust ability to identify shortcut cues, even on low-resolution 64x64 images. The accompanying "Deletion Curve" plot validates the successful prediction flip resulting from the removal of these regions.
  • Figure 5: Demonstrating SS-CA's scalability during the training process on ImageNet-1k. The figure shows our submodular search accurately generating "hard" counterfactual samples used for debiasing on this large-scale dataset. The accompanying "Deletion Curve" confirms the successful prediction flip.
  • ...and 1 more figure
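
The "Deletion Curve" plots mentioned in the captions above can be approximated with a short sketch: regions are removed in a given order, and the score of the original class is recorded after each removal, together with the first step at which the predicted class changes. How the paper computes its curves is not specified in this summary, so the function name `deletion_curve`, the row-wise regions, the zero background, and the toy model below are all assumptions for illustration.

```python
import numpy as np

def deletion_curve(image, label, predict_proba, regions, order, background):
    """Score of `label` after each cumulative region removal.

    Returns (scores, flip_step), where scores[0] is the score on the intact
    image and flip_step is the number of removed regions at which the
    predicted class first differs from `label` (None if it never flips).
    """
    current = image.copy()
    probs = predict_proba(current)
    scores = [float(probs[label])]
    preds = [int(np.argmax(probs))]
    for idx in order:
        rs, cs = regions[idx]
        current[rs, cs] = background[rs, cs]  # remove one more region
        probs = predict_proba(current)
        scores.append(float(probs[label]))
        preds.append(int(np.argmax(probs)))
    flip_step = next((i for i, p in enumerate(preds) if p != label), None)
    return scores, flip_step

# Toy two-class "model" (hypothetical): class 1 score is the mean intensity.
def mean_intensity_model(img):
    s = float(img.mean())
    return np.array([1.0 - s, s])

img = np.ones((5, 4))
bg = np.zeros_like(img)                                      # assumed background
rows = [(slice(r, r + 1), slice(0, 4)) for r in range(5)]    # one region per row

curve, flip_step = deletion_curve(img, 1, mean_intensity_model,
                                  rows, order=[0, 1, 2, 3, 4], background=bg)
```

A steep early drop in `curve` with a small `flip_step` would indicate that the model depends on only a few regions, which is the situation the counterfactual augmentation is meant to correct.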