Counterfactual Reasoning for Multi-Label Image Classification via Patching-Based Training

Ming-Kun Xie, Jia-Hao Xiao, Pei Peng, Gang Niu, Masashi Sugiyama, Sheng-Jun Huang

TL;DR

This work casts multi-label image classification as a causal problem where co-occurrence between objects acts as a mediator $O$ that can both aid and misguide predictions through the path $X\rightarrow O\rightarrow Y$. It introduces a counterfactual total direct effect (TDE) objective to bolster the direct influence of the target object while suppressing mediator-induced errors, realized via patching-based inference (PAT-I) and patching-based training (PAT-T). Empirical results on MS-COCO, VOC 2007, and VG-256 demonstrate state-of-the-art performance, with ablations validating the contribution of TDE and patching. The approach offers a flexible, plug-in framework that improves robustness to label co-occurrence without requiring substantial changes to existing architectures.

Abstract

The key to multi-label image classification (MLC) is to improve model performance by leveraging label correlations. Unfortunately, overemphasizing co-occurrence relationships has been shown to cause overfitting, ultimately degrading performance. In this paper, we provide a causal inference framework showing that the correlative features induced by the target object and its co-occurring objects can be regarded as a mediator, which has both positive and negative effects on model predictions. On the positive side, the mediator improves recognition by capturing co-occurrence relationships; on the negative side, it has a harmful causal effect that leads the model to predict the target object even when only its co-occurring objects are present in the image. To address this problem, we propose a counterfactual reasoning method to measure the total direct effect, which is achieved by enhancing the direct effect caused only by the target object. Because the location of the target object is unknown, we propose patching-based training and inference, which divide an image into multiple patches and identify the pivot patch that contains the target object. Experimental results on multiple benchmark datasets with diverse configurations show that the proposed method achieves state-of-the-art performance.
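The patching idea in the abstract (divide an image into patches, then identify the pivot patch that contains the target object) can be illustrated with a minimal sketch. This is not the authors' PAT-I or PAT-T procedure, whose details are not given here. The non-overlapping grid, the `score_fn` classifier, and the rule "pivot patch = highest-scoring patch per class" are all assumptions made for illustration.

```python
import numpy as np

def split_into_patches(image, grid=2):
    """Split an H x W x C image into grid*grid non-overlapping patches (row-major order)."""
    h, w = image.shape[:2]
    ph, pw = h // grid, w // grid
    return [image[r * ph:(r + 1) * ph, c * pw:(c + 1) * pw]
            for r in range(grid) for c in range(grid)]

def pivot_patches(image, score_fn, grid=2):
    """Return, for each class, the index of its highest-scoring patch.

    score_fn is a hypothetical stand-in for a trained classifier:
    it maps one patch to a vector of per-class scores.
    """
    patches = split_into_patches(image, grid)
    scores = np.stack([score_fn(p) for p in patches])  # (num_patches, num_classes)
    return scores.argmax(axis=0), scores

def toy_score_fn(patch):
    # Toy scorer (assumption): class 0 responds to the red channel, class 1 to the blue channel.
    return np.array([patch[..., 0].mean(), patch[..., 2].mean()])

image = np.zeros((4, 4, 3))
image[0:2, 2:4, 0] = 1.0  # "red object" in the top-right patch (index 1)
image[2:4, 0:2, 2] = 1.0  # "blue object" in the bottom-left patch (index 2)

pivots, scores = pivot_patches(image, toy_score_fn)
print(pivots)  # per-class pivot patch indices
```

In this toy setup each class is scored from its own patch alone, so a co-occurring object elsewhere in the image cannot raise the score of a class that is absent from that patch.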

Paper Structure (25 sections, 8 equations, 6 figures, 6 tables, 2 algorithms)

Figures (6)

  • Figure 1: Comparison of InT (each class has its own visual backbone) and DeT (all classes share one visual backbone) in terms of conditional TPR (a) and conditional FPR (b) for given class pairs on MS-COCO. DeT attains higher conditional TPR and FPR than InT, indicating that co-occurrence relationships are encoded in the feature representations and exert both positive and negative effects on model predictions.
  • Figure 2: An illustration of the causal graph. The dashed line indicates that the two objects co-occur in an image.
  • Figure 3: An illustration of TDE inference for MLC.
  • Figure 4: Visualization of model predictions on MS-COCO. Ori denotes the probability predicted by the PAT-T model on the original image; Pat denotes the probability predicted by the PAT-T model on patches. The weight of each patch is reported in the four-cell grid.
  • Figure 5: Conditional FPR (CFPR) and conditional TPR (CTPR) of the label pairs with co-occurrence probabilities larger than 0.2 on MS-COCO. The pair indices are sorted according to the performance of Baseline. PAT-T outperforms Baseline on around 76% and 77% of label pairs in terms of CFPR and CTPR, respectively.
  • ...and 1 more figures
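
The captions for Figures 1 and 5 report conditional TPR and conditional FPR for label pairs but do not define them. The sketch below uses one plausible reading, stated as an assumption: for a target class $i$ and a co-occurring class $j$, CTPR is the fraction of images containing both $i$ and $j$ in which $i$ is predicted, and CFPR is the fraction of images containing $j$ but not $i$ in which $i$ is still predicted. The function name and the toy data are hypothetical.

```python
import numpy as np

def conditional_rates(y_true, y_pred, target, context):
    """Conditional TPR and FPR of `target` given that `context` is present.

    Assumed definitions (not taken from the paper):
      CTPR = P(pred_target = 1 | y_target = 1, y_context = 1)
      CFPR = P(pred_target = 1 | y_target = 0, y_context = 1)
    y_true and y_pred are binary arrays of shape (num_images, num_classes).
    """
    ctx = y_true[:, context] == 1
    pos = ctx & (y_true[:, target] == 1)
    neg = ctx & (y_true[:, target] == 0)
    ctpr = float(y_pred[pos, target].mean()) if pos.any() else float("nan")
    cfpr = float(y_pred[neg, target].mean()) if neg.any() else float("nan")
    return ctpr, cfpr

# Toy labels and predictions: column 0 is the target class, column 1 the co-occurring class.
y_true = np.array([[1, 1], [1, 1], [0, 1], [0, 1], [0, 0]])
y_pred = np.array([[1, 1], [1, 1], [1, 1], [0, 1], [0, 0]])

ctpr, cfpr = conditional_rates(y_true, y_pred, target=0, context=1)
print(ctpr, cfpr)  # 1.0 0.5
```

Under this reading, a high CFPR reflects the negative effect of the mediator described in the abstract: the target is predicted because its usual companion is present, not because the target itself is.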