
DiffusionCounterfactuals: Inferring High-dimensional Counterfactuals with Guidance of Causal Representations

Jiageng Zhu, Hanchen Xie, Jiazhi Li, Wael Abd-Almageed

TL;DR

DiffusionCounterfactuals tackles high-dimensional counterfactual inference by combining diffusion models with causal representations. The framework jointly learns reconstruction and causal mechanisms using a generator $g_\theta$, a causal projector $h_\phi$, and an NSCM $s_\phi$, optimizing $\log p(x)$ and $\log p(\mathcal{Z}|x)$. Counterfactual samples are produced by a gradient-guided reverse diffusion step conditioned on the intervened causal factors, with a self-adjusted scalar $\lambda_t$ that balances guidance strength against diffusion uncertainty. Across six datasets, the method shows improved counterfactual quality and causal consistency (Attribute Consistency Metrics, ACM) and competitive image fidelity (FID/sFID, PSNR) compared to state-of-the-art baselines, which supports its use for scenario analysis and decision-making in high-dimensional domains.
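To make the gradient-guided reverse step concrete, here is a minimal sketch in the style of generic classifier guidance, not the paper's exact update rule. It assumes a toy linear projector in place of $h_\phi$, a Gaussian likelihood for the intervened factor, and a guidance scale `lam` passed in by hand; how $\lambda_t$ is self-adjusted is not specified here and is left out. All function names and numeric values are hypothetical.

```python
import random


def project(x, w):
    """Toy linear stand-in for a causal projector: maps x to one scalar factor."""
    return sum(wi * xi for wi, xi in zip(w, x))


def grad_log_gaussian(x, w, z_target, tau):
    """Gradient w.r.t. x of log N(z_target; project(x, w), tau^2)."""
    scale = (z_target - project(x, w)) / tau ** 2
    return [scale * wi for wi in w]


def guided_mean(mean, sigma, grad, lam):
    """Shift the unguided reverse-step mean by lam * sigma^2 * grad."""
    return [m + lam * sigma ** 2 * g for m, g in zip(mean, grad)]


def guided_reverse_step(mean, sigma, grad, lam, rng):
    """Sample x_{t-1} around the guided mean with step noise sigma."""
    shifted_mean = guided_mean(mean, sigma, grad, lam)
    return [m + sigma * rng.gauss(0.0, 1.0) for m in shifted_mean]


# Hypothetical toy values: a 2-D "image" whose factor reads the first coordinate.
w = [1.0, 0.0]
mean_t = [0.0, 0.0]   # unguided mean, as a denoiser might predict it
z_intervened = 1.0    # factor value after the intervention
sigma_t, lam_t = 0.5, 2.0  # lam_t is fixed by hand in this sketch

grad = grad_log_gaussian(mean_t, w, z_intervened, tau=1.0)
shifted = guided_mean(mean_t, sigma_t, grad, lam_t)
sample = guided_reverse_step(mean_t, sigma_t, grad, lam_t, random.Random(0))
print(shifted)
```

With `lam = 0` the step reduces to ordinary unguided sampling, while a larger `lam` pulls the mean harder toward the intervened factor value. That is the trade-off the scalar $\lambda_t$ is described as balancing.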

Abstract

Accurate estimation of counterfactual outcomes in high-dimensional data is crucial for decision-making and for understanding causal relationships and intervention outcomes in various domains, including healthcare, economics, and social sciences. However, existing methods often struggle to generate accurate and consistent counterfactuals, particularly when the causal relationships are complex. We propose a novel framework that incorporates causal mechanisms and diffusion models to generate high-quality counterfactual samples guided by causal representations. Our approach introduces a novel, theoretically grounded training and sampling process that enables the model to consistently generate accurate counterfactual high-dimensional data under multiple intervention steps. Experimental results on various synthetic and real benchmarks demonstrate that the proposed approach outperforms state-of-the-art methods in generating accurate and high-quality counterfactuals across different evaluation metrics.

Paper Structure (21 sections, 21 equations, 18 figures, 5 tables)


Figures (18)

  • Figure 1: Guided by causality, diffusion generation can answer “what if” questions, i.e., perform counterfactual inference, in the input space.
  • Figure 2: DiffusionCounterfactuals Framework: Training and Counterfactual Inference. (a) The DiffusionCounterfactuals training process learns a diffusion reconstruction and discovers the causal factors along with their corresponding causal relations. (b) The counterfactual inference process generates counterfactual images by conditioning on estimated intervention outcomes of specific factors in the generative factor space, using the learned model.
  • Figure 3: Illustration of the effects of different $\lambda$ values on guided sampling: (a) A target sample (red box) is chosen, and the distribution of the difference in predicted generative factors (Pendulum angle and Light position) between the target sample and other samples is modeled as a 2D Gaussian using a causal predictor. (b) Focusing on the Light position attribute, the orange area represents predictions from the noisy target sample, while the blue area represents predictions from other samples. Extreme $\lambda$ values fail to guide generation correctly. In contrast, our proposed self-adjusted $\lambda_t$ adjusts the strength for proper guidance.
  • Figure 4: Sequential counterfactual generation results. We evaluate sequentially generated counterfactual samples in an autoregressive manner. More results are included in the Appendix.
  • Figure 5: Attribute Consistency Metrics (ACM) calculation pipeline.
  • ...and 13 more figures