Mitigating attribute amplification in counterfactual image generation

Tian Xia, Mélanie Roschewitz, Fabio De Sousa Ribeiro, Charles Jones, Ben Glocker

TL;DR

This work shows that attribute amplification is caused by the use of hard labels in the counterfactual training process and proposes soft counterfactual fine-tuning to mitigate it, an important step towards more faithful and unbiased causal modelling in medical imaging.

Abstract

Causal generative modelling is gaining interest in medical imaging due to its ability to answer interventional and counterfactual queries. Most work focuses on generating counterfactual images that look plausible, using auxiliary classifiers to enforce the effectiveness of simulated interventions. We investigate pitfalls in this approach, discovering the issue of attribute amplification, where unrelated attributes are spuriously affected during interventions, leading to biases across protected characteristics and disease status. We show that attribute amplification is caused by the use of hard labels in the counterfactual training process and propose soft counterfactual fine-tuning to mitigate this issue. Our method substantially reduces the amplification effect while maintaining the effectiveness of the generated images, as demonstrated on a large chest X-ray dataset. Our work makes an important advancement towards more faithful and unbiased causal modelling in medical imaging.
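
The abstract attributes amplification to hard labels but does not spell out the fine-tuning objective. The sketch below illustrates one possible reading, assuming that Soft-CFT replaces the hard ground-truth labels of non-intervened attributes with the auxiliary classifier's probabilities on the original image, while intervened attributes keep hard targets. The attribute names, the probability values, the binary cross-entropy loss and all helper functions (`bce`, `cft_targets`, `cft_loss`) are hypothetical and chosen only for illustration.

```python
import math


def bce(p, target, eps=1e-7):
    """Binary cross-entropy between a predicted probability p and a target in [0, 1]."""
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def cft_targets(orig_labels, orig_probs, intervention, soft=True):
    """Per-attribute classifier targets for a counterfactual image (hypothetical helper).

    orig_labels:  hard ground-truth labels of the original image x
    orig_probs:   auxiliary-classifier probabilities on x (assumed available)
    intervention: attributes fixed by the intervention, e.g. {"disease": 1}
    soft:         False -> hard-label targets, True -> assumed Soft-CFT targets
    """
    targets = {}
    for attr, label in orig_labels.items():
        if attr in intervention:
            # Intervened attribute: the counterfactual should show the new value.
            targets[attr] = float(intervention[attr])
        elif soft:
            # Assumption: keep the classifier's confidence on the original image,
            # so the loss does not reward exaggerating this attribute.
            targets[attr] = orig_probs[attr]
        else:
            # Hard label: the loss keeps decreasing as the prediction moves to 0 or 1.
            targets[attr] = float(label)
    return targets


def cft_loss(cf_probs, targets):
    """Sum of per-attribute BCE terms on the counterfactual's classifier outputs."""
    return sum(bce(cf_probs[attr], t) for attr, t in targets.items())


orig_labels = {"disease": 0, "sex": 1}
orig_probs = {"disease": 0.2, "sex": 0.7}  # illustrative values, not from the paper
intervention = {"disease": 1}

hard_t = cft_targets(orig_labels, orig_probs, intervention, soft=False)
soft_t = cft_targets(orig_labels, orig_probs, intervention, soft=True)

# Two candidate counterfactuals: one preserves the sex prediction, one amplifies it.
preserved = {"disease": 0.9, "sex": 0.7}
amplified = {"disease": 0.9, "sex": 0.99}
print(cft_loss(amplified, hard_t) < cft_loss(preserved, hard_t))  # True: hard labels favour amplification
print(cft_loss(preserved, soft_t) < cft_loss(amplified, soft_t))  # True: soft targets favour preservation
```

Because binary cross-entropy against a target t is minimised at p = t, a hard target of 1 for an attribute the intervention should leave untouched keeps rewarding the generator for making that attribute more pronounced, whereas a soft target anchors it at the original image's level.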

Paper Structure (16 sections, 5 figures, 2 tables, 1 algorithm)

Figures (5)

  • Figure 1: Generated CFs with (a) Hard-CFT and (b) Soft-CFT. First rows show original image $\mathbf{x}$ and CFs $\mathbf{\widetilde{x}}$; second rows show direct effect of CFs, i.e. $\mathbf{\widetilde{x}}-\mathbf{x}$.
  • Figure 2: Marginal distribution of image embeddings along PCA mode 3 of a multi-task model. This mode encodes changes in the race attribute. We plot distributions for subgroups of real data alongside distributions of counterfactual images when intervening on disease and sex attributes. When training with Hard-CFT (left), there is a clear distribution shift between real (blue) and counterfactual images (red). This shift is removed for both interventions when using our proposed Soft-CFT (right). These results suggest that Soft-CFT successfully mitigates attribute amplification and generates more faithful counterfactual images.
  • Figure A1: Illustration of how attribute amplification may violate the causal graph predefined for the deep structural causal model (DSCM), which may lead to spurious correlations between protected characteristics and disease status encoded in the counterfactual images.
  • Figure A2: Generated CFs with (a) Hard-CFT and (b) Soft-CFT. Top rows show original image $\mathbf{x}$ and CFs $\mathbf{\widetilde{x}}$; bottom rows show direct effect of CFs, i.e. $\mathbf{\widetilde{x}}-\mathbf{x}$.
  • Figure A3: Marginal distributions of PCA modes of pretrained embeddings from a multi-task model trained to predict all attributes. We plot embeddings of real data alongside generated counterfactuals of various subgroups. When training with Hard-CFT (left), there is a distribution shift between real images and images after intervention (red). Conversely, this shift is mitigated when using our proposed soft counterfactual fine-tuning (Soft-CFT, right).
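
Figures 2 and A3 compare marginal distributions of PCA modes of multi-task embeddings for real and counterfactual images. The following is a minimal sketch of that comparison, assuming the embeddings are already extracted as NumPy arrays (rows are images) and that "mode 3" is one-based. Fitting PCA on real embeddings only, the synthetic data, the helper names and the use of the 1-D Wasserstein distance as a numeric summary of the shift are all assumptions for illustration; the captions describe the comparison visually.

```python
import numpy as np


def fit_pca(real_emb, n_modes):
    """Fit PCA on real-image embeddings (rows = images); return mean and top axes."""
    mean = real_emb.mean(axis=0)
    # Rows of vt are orthonormal principal directions, sorted by explained variance.
    _, _, vt = np.linalg.svd(real_emb - mean, full_matrices=False)
    return mean, vt[:n_modes]


def project(emb, mean, axes):
    """Coordinates of each embedding along the fitted PCA modes."""
    return (emb - mean) @ axes.T


def wasserstein_1d(a, b):
    """W1 distance between two equal-size 1-D samples via sorted matching."""
    a, b = np.sort(a), np.sort(b)
    if a.shape != b.shape:
        raise ValueError("this simple version expects equal sample sizes")
    return float(np.mean(np.abs(a - b)))


# Synthetic stand-ins for embeddings from a multi-task classifier (not real data).
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 16))
mean, axes = fit_pca(real, n_modes=5)

mode = 2  # zero-based index, i.e. the third PCA mode
faithful_cf = real.copy()              # counterfactuals with no spurious change
shifted_cf = real + 0.5 * axes[mode]   # counterfactuals drifting along that mode

for name, cf in [("faithful", faithful_cf), ("shifted", shifted_cf)]:
    d = wasserstein_1d(project(real, mean, axes)[:, mode],
                       project(cf, mean, axes)[:, mode])
    print(name, round(d, 3))
```

On this synthetic data the faithful set gives a distance of 0.0 along the third mode and the shifted set gives 0.5, the size of the injected drift. The equal-size restriction keeps the sorted-matching formula exact; `scipy.stats.wasserstein_distance` covers samples of different sizes.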