Did Models Sufficient Learn? Attribution-Guided Training via Subset-Selected Counterfactual Augmentation
Yannan Chen, Ruoyu Chen, Bin Zeng, Wei Wang, Shiming Liu, Qunli Zhang, Zheng Hu, Laiyuan Wang, Yaowei Wang, Xiaochun Cao
TL;DR
The paper attacks shortcut learning in visual models by integrating attribution-guided counterfactual reasoning into training. It introduces Subset-Selected Counterfactual Augmentation (SS-CA) together with Counterfactual LIMA, an attribution method that identifies minimal sets of regions whose removal flips the model's prediction. SS-CA then replaces those regions with plausible background content to generate hard yet semantically consistent augmented samples. By selecting regions with a submodular greedy search and training on a joint objective that combines standard supervision on the original images with supervision on these augmented samples, SS-CA encourages the model to learn more complete, causal decision rules. Experiments across ImageNet variants and multiple backbones show consistent gains in in-distribution accuracy and out-of-distribution robustness (including ImageNet-R and ImageNet-S), as well as improved performance under common corruptions, underscoring the practical value of combining interpretability with training-time debiasing.
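The region search can be illustrated with a simplified greedy loop. The sketch below is not Counterfactual LIMA itself. It assumes the image is pre-partitioned into boolean region masks (for example superpixels or grid patches), that a hypothetical `predict` callable returns one score per class, and that a region is "removed" by copying in pixels from a same-shaped `background` array. In place of the paper's submodular objective it uses a simpler, hypothetical criterion: at each step it adds the region whose removal most lowers the target-class score, and it stops as soon as the predicted class changes. Greedy selection of this kind returns a small set but does not guarantee the smallest one; all names are illustrative.

```python
import numpy as np

# Hypothetical sketch: a simplified greedy stand-in for the counterfactual
# region search, NOT the authors' Counterfactual LIMA implementation.


def apply_regions(image, background, regions, selected):
    """Return a copy of `image` with every selected region filled from `background`."""
    edited = image.copy()
    for idx in selected:
        mask = regions[idx]              # boolean (H, W) mask
        edited[mask] = background[mask]  # background has the same shape as image
    return edited


def greedy_counterfactual_regions(image, regions, predict, target, background):
    """Greedily add regions until removing them changes the predicted class.

    Returns (selected_indices, edited_image), or ([], unchanged copy) if
    removing every region never flips the prediction.
    """
    remaining = list(range(len(regions)))
    selected = []
    while remaining:
        # Target-class score after additionally removing each candidate region.
        scores = [predict(apply_regions(image, background, regions, selected + [r]))[target]
                  for r in remaining]
        best = remaining[int(np.argmin(scores))]
        selected.append(best)
        remaining.remove(best)
        edited = apply_regions(image, background, regions, selected)
        if int(np.argmax(predict(edited))) != target:
            return selected, edited  # prediction flipped: counterfactual found
    return [], image.copy()


# Toy usage: a 4x4 single-channel image split into four 2x2 quadrants
# (top-left, top-right, bottom-left, bottom-right) and a hypothetical
# two-class "model" whose class-0 score is the sum of the top half.
img = np.zeros((4, 4, 1))
img[:2, :2] = 1.0
img[:2, 2:] = 0.6
quadrants = []
for rows in (slice(0, 2), slice(2, 4)):
    for cols in (slice(0, 2), slice(2, 4)):
        m = np.zeros((4, 4), dtype=bool)
        m[rows, cols] = True
        quadrants.append(m)


def toy_predict(x):
    """Class 0 scores the top-half sum; class 1 is a constant threshold of 2.0."""
    return np.array([x[:2].sum(), 2.0])


selected, edited = greedy_counterfactual_regions(img, quadrants, toy_predict, 0, np.zeros_like(img))
print(selected)  # → [0, 1]
```

Removing the top-left quadrant alone leaves a class-0 score of 2.4, still above the threshold, so the loop also removes the top-right quadrant before the prediction flips. The edited image, with the original label kept, is the kind of hard sample the augmentation step would feed back into training.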
Abstract
In current visual model training, models often rely on only a limited set of sufficient causes for their predictions, which makes them sensitive to distribution shifts or to the absence of key features. Attribution methods can accurately identify a model's critical regions. However, masking these regions to create counterfactuals often causes the model to misclassify the target, even though humans can still easily recognize it. This divergence indicates that the model's learned dependencies may not be sufficiently causal. To address this issue, we propose Subset-Selected Counterfactual Augmentation (SS-CA), which integrates counterfactual explanations directly into the training process for targeted intervention. Building on the subset-selection-based LIMA attribution method, we develop Counterfactual LIMA to identify minimal sets of spatial regions whose removal can selectively alter the model's prediction. Leveraging these attributions, we introduce a data augmentation strategy that replaces the identified regions with natural background content, and we train the model jointly on augmented and original samples to mitigate incomplete causal learning. Extensive experiments across multiple ImageNet variants show that SS-CA improves generalization on in-distribution (ID) test data and achieves superior performance on out-of-distribution (OOD) benchmarks such as ImageNet-R and ImageNet-S. Models trained with SS-CA also generalize better under perturbations such as noise, demonstrating that our approach effectively uses interpretability insights to correct model deficiencies and improve both performance and robustness.
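The joint training step can be written as a weighted sum of two supervised losses. The sketch below is a minimal NumPy version under stated assumptions: each counterfactually augmented image keeps the label of its original (the premise being that a human would still recognize the object), both terms are softmax cross-entropy, and `aug_weight` is a hypothetical balancing coefficient rather than a value taken from the paper.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy for logits of shape (N, K) and integer labels (N,)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def joint_loss(logits_orig, logits_aug, labels, aug_weight=1.0):
    """Loss on original images plus a weighted loss on their counterfactual
    augmentations; the augmented images reuse the original labels (assumption)."""
    return cross_entropy(logits_orig, labels) + aug_weight * cross_entropy(logits_aug, labels)


# Toy usage with hypothetical logits for a batch of two images and three classes.
labels = np.array([0, 1])
logits_orig = np.array([[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]])  # confident and correct
logits_aug = np.zeros((2, 3))  # uncertain on the augmented views
loss = joint_loss(logits_orig, logits_aug, labels, aug_weight=0.5)
print(round(loss, 4))
```

In this toy batch the model is confident on the originals but uniform on the augmented views, so most of the loss comes from the augmented term. That term is the signal the joint objective uses to push the model toward evidence beyond the regions it already relies on. Keeping the two terms separate also makes it easy to ablate the augmentation by setting `aug_weight` to zero.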
