Mitigating Shortcut Learning with Diffusion Counterfactuals and Diverse Ensembles

Luca Scimeca, Alexander Rubinstein, Damien Teney, Seong Joon Oh, Yoshua Bengio

TL;DR

This work addresses shortcut learning, where models exploit spurious cues in training data, by proposing DiffDiv, a diffusion-based ensemble strategy. The method trains a diffusion probabilistic model on the target data and uses diffusion-generated counterfactuals to drive model disagreement within an ensemble, removing reliance on shortcuts without needing labeled out-of-distribution data. Empirical results on ColorDSprites, UTKFace, and CelebA show that diffusion-guided diversification can achieve bias mitigation and ensemble diversity on par with data-dependent approaches, provided the diffusion model is trained to an appropriate, intermediate fidelity and early stopping is used. Overall, DiffDiv offers a scalable, data-efficient route to debias models and improve generalization by leveraging synthetic counterfactuals from diffusion models.

Abstract

Spurious correlations in the data, where multiple cues are predictive of the target labels, often lead to a phenomenon known as shortcut learning, where a model relies on erroneous, easy-to-learn cues while ignoring reliable ones. In this work, we propose DiffDiv, an ensemble diversification framework that exploits Diffusion Probabilistic Models (DPMs) to mitigate this form of bias. We show that, at particular stages of training, DPMs can generate images with novel feature combinations, even when trained on samples whose input features are correlated. We leverage this property to generate synthetic counterfactuals that increase model diversity through ensemble disagreement. We show that DPM-guided diversification is sufficient to remove dependence on shortcut cues without additional supervised signals. We further quantify its efficacy empirically across several diversification objectives, and show improved generalization and diversification on par with prior work that relies on auxiliary data collection.
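The abstract describes training an ensemble whose members are pushed to disagree on DPM-generated counterfactuals. Below is a minimal sketch of one such objective, assuming a simple pairwise agreement penalty that is not necessarily one of the objectives evaluated in the paper: every member is fit to the labeled data with cross-entropy, and each pair of members is penalized with `-log(1 - agreement)`, where agreement is the probability that both predict the same class on a synthetic sample. The function names, the weight `lam`, and the penalty form are illustrative assumptions.

```python
import math
from itertools import combinations


def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy(logits, label):
    # Stable -log softmax(logits)[label] via the log-sum-exp trick.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def pairwise_agreement_penalty(probs_a, probs_b, eps=1e-8):
    # Probability that two independent members pick the same class:
    # sum_k p_a(k) * p_b(k). The penalty grows as this approaches 1.
    agree = sum(pa * pb for pa, pb in zip(probs_a, probs_b))
    return -math.log(1.0 - agree + eps)


def diversified_loss(member_logits_labeled, labels, member_logits_synth, lam=1.0):
    """Hypothetical DiffDiv-style loss (illustrative assumption).

    member_logits_labeled: per member, a list of logit vectors on labeled samples.
    labels: integer class labels for the labeled samples.
    member_logits_synth: per member, a list of logit vectors on DPM counterfactuals.
    """
    n_members = len(member_logits_labeled)
    # Task term: mean cross-entropy of every member on the labeled data.
    task = 0.0
    for logits_m in member_logits_labeled:
        task += sum(cross_entropy(z, y) for z, y in zip(logits_m, labels)) / len(labels)
    task /= n_members
    # Diversity term: mean agreement penalty over all member pairs,
    # evaluated only on the synthetic (unlabeled) counterfactuals.
    probs = [[softmax(z) for z in logits_m] for logits_m in member_logits_synth]
    pairs = list(combinations(range(n_members), 2))
    div = 0.0
    for i, j in pairs:
        div += sum(
            pairwise_agreement_penalty(pi, pj) for pi, pj in zip(probs[i], probs[j])
        ) / len(probs[i])
    div /= max(len(pairs), 1)
    return task + lam * div
```

With two members that fit the labeled data equally well, this loss is lower when they assign different classes to the same counterfactual, which is the kind of pressure meant to keep members from all relying on the same shortcut cue.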

Paper Structure (12 sections, 9 equations, 4 figures, 1 table)

Figures (4)

  • Figure 1: DiffDiv: We sample from a DPM to generate synthetic counterfactuals showcasing emergent novel feature combinations. These samples are then used to build a diverse model ensemble via different ensemble disagreement objectives.
  • Figure 2: DPM training and counterfactual generation. Although the DPM is trained on images showing a correlated set of features (left columns), at appropriate fidelity levels its samples include novel objects beyond the observed feature combinations (marked images, right-hand side).
  • Figure 3: Frequency of out-of-distribution (OOD) samples generated by DPMs trained to different fidelities on ColorDSprites.
  • Figure 4: Output diversity when using samples from diffusion models at different fidelities, compared with using real OOD samples (OOD) or no diversification (BS), on ColorDSprites. Full results are reported in the supplementary material.
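Figure 3 reports how often DPM samples fall outside the training distribution. As a sketch of the idea, not the paper's evaluation protocol, and assuming each generated image has already been annotated with its (shape, color) attributes (for instance by hypothetical per-attribute classifiers, not shown), that frequency could be computed as the fraction of generated samples whose attribute combination never occurs in the training set. The function name and the example attribute pairings are illustrative assumptions.

```python
def ood_frequency(train_attrs, generated_attrs):
    """Fraction of generated samples with an attribute combination unseen in training.

    train_attrs / generated_attrs: iterables of hashable tuples, e.g. (shape, color).
    How attributes are obtained for generated images is assumed, not specified here.
    """
    seen = set(train_attrs)  # combinations observed in the (correlated) training data
    generated = list(generated_attrs)
    if not generated:
        return 0.0
    novel = sum(1 for combo in generated if combo not in seen)
    return novel / len(generated)


# Hypothetical example: in training, each shape appears with a single color.
train = [("square", "red"), ("ellipse", "green"), ("heart", "blue")]
generated = [("square", "red"), ("square", "green"), ("heart", "blue"), ("ellipse", "red")]
print(ood_frequency(train, generated))  # 2 of 4 combinations are novel
```

Evaluating this per DPM checkpoint would give one value per fidelity level, which is the quantity a plot like Figure 3 tracks.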