Causal Diffusion Autoencoders: Toward Counterfactual Generation via Diffusion Probabilistic Models

Aneesh Komanduri, Chen Zhao, Feng Chen, Xintao Wu

TL;DR

CausalDiffAE addresses the gap between diffusion models and causal, controllable representations by learning two latents: a disentangled causal representation z_causal, whose factors are related by a latent structural causal model (SCM), and a stochastic latent x_T that captures the remaining variation. Counterfactuals are generated by applying do-interventions to z_causal. The framework couples a causal encoder with a DDIM-based decoder and trains them with a variational objective that includes a label-aligned prior, yielding well-disentangled latents and accurate, realistic counterfactuals. On MorphoMNIST, Pendulum, and CausalCircuit, it outperforms baselines in latent disentanglement (DCI) and counterfactual effectiveness (MAE), and a weakly supervised variant remains robust when labels are limited. The approach is a step toward reliable, causally grounded diffusion-based generation, with practical uses in data augmentation and counterfactual reasoning for vision tasks.
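The label-aligned prior enters the variational objective as a KL term between the encoder's posterior over z_causal and a prior that depends on the labels. The sketch below is illustrative only: it assumes a diagonal Gaussian posterior and a hypothetical prior p(z_causal | y) = N(y, I) with one latent dimension per labeled factor, and the `objective` weighting `beta` is likewise an assumption. The paper's exact prior parameterization and loss weights are not stated in this summary.

```python
import numpy as np

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, diag(exp(logvar_q))) || N(mu_p, diag(exp(logvar_p)))), summed over the last axis."""
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return kl.sum(axis=-1)

def label_prior_kl(mu_q, logvar_q, y):
    """Hypothetical label-aligned prior p(z_causal | y) = N(y, I): one latent per labeled factor."""
    return gaussian_kl(mu_q, logvar_q, mu_p=y, logvar_p=np.zeros_like(y))

def objective(denoise_mse, mu_q, logvar_q, y, beta=1.0):
    """Hypothetical per-example loss: denoising/reconstruction term + beta * KL to the label prior."""
    return denoise_mse + beta * label_prior_kl(mu_q, logvar_q, y)

# Usage: a posterior centred on the labels pays no KL; shifting every mean by 1 costs 0.5 per factor.
y = np.array([[0.2, -1.0, 0.5]])
logvar_q = np.zeros((1, 3))
print(label_prior_kl(y, logvar_q, y))        # [0.]
print(label_prior_kl(y + 1.0, logvar_q, y))  # [1.5]
```

Pulling the posterior means toward the labels is what ties each latent dimension to one semantic factor; with unit prior variance the penalty is simply half the squared distance from the labels plus a variance term.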

Abstract

Diffusion probabilistic models (DPMs) have become the state of the art in high-quality image generation. However, DPMs have an arbitrary noisy latent space with no interpretable or controllable semantics. Although there has been significant research effort to improve image sample quality, there is little work on representation-controlled generation using diffusion models. In particular, causal modeling and controllable counterfactual generation using DPMs remain underexplored. In this work, we propose CausalDiffAE, a diffusion-based causal representation learning framework that enables counterfactual generation according to a specified causal model. Our key idea is to use an encoder to extract high-level, semantically meaningful causal variables from high-dimensional data and to model stochastic variation using reverse diffusion. We propose a causal encoding mechanism that maps high-dimensional data to causally related latent factors, and we parameterize the causal mechanisms among latent factors using neural networks. To enforce the disentanglement of causal variables, we formulate a variational objective and leverage auxiliary label information in a prior to regularize the latent space. We propose a DDIM-based counterfactual generation procedure subject to do-interventions. Finally, to address settings with limited label supervision, we study the application of CausalDiffAE when part of the training data is unlabeled; this variant also enables granular control over the strength of interventions when generating counterfactuals at inference time. We empirically show that CausalDiffAE learns a disentangled latent space and is capable of generating high-quality counterfactual images.
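The causal encoding mechanism and do-interventions described in the abstract can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation: it assumes a known 3-node DAG with a hypothetical adjacency matrix `A`, an additive-noise form z_j = g_j(parents of z_j) + eps_j in which each mechanism g_j is a tiny random MLP, and exogenous noise `eps` that an encoder would output. A counterfactual under do(z_k = c) keeps the abducted noise fixed, sets z_k to c, and recomputes only its descendants.

```python
import numpy as np

# Hypothetical DAG over three causal factors: z0 -> z1, z0 -> z2, z1 -> z2.
# A[i, j] = 1 means z_i is a parent of z_j; ORDER is a topological order.
A = np.array([[0, 1, 1],
              [0, 0, 1],
              [0, 0, 0]], dtype=float)
ORDER = (0, 1, 2)

# Hypothetical mechanisms g_j: one-hidden-layer MLPs with fixed random weights.
rng = np.random.default_rng(0)
W1 = rng.normal(size=(3, 3, 8))  # W1[j]: 3 inputs -> 8 hidden units
W2 = rng.normal(size=(3, 8))     # W2[j]: 8 hidden units -> scalar

def mechanism(j, z):
    """g_j evaluated on the parents of z_j only (non-parents are masked to zero)."""
    hidden = np.tanh((A[:, j] * z) @ W1[j])
    return hidden @ W2[j]

def causal_encode(eps):
    """Map exogenous noise eps to causal factors: z_j = g_j(parents) + eps_j."""
    z = np.zeros_like(eps)
    for j in ORDER:
        z[j] = mechanism(j, z) + eps[j]
    return z

def abduct(z):
    """Recover the exogenous noise implied by factual factors z."""
    return np.array([z[j] - mechanism(j, z) for j in range(len(z))])

def do_intervention(z, k, value):
    """Counterfactual factors under do(z_k = value) with the abducted noise held fixed."""
    eps = abduct(z)
    z_cf = np.zeros_like(z)
    for j in ORDER:
        z_cf[j] = value if j == k else mechanism(j, z_cf) + eps[j]
    return z_cf

# Usage: encode noise into causal factors, then intervene on z1.
eps = np.array([0.3, -0.7, 1.1])
z = causal_encode(eps)
z_cf = do_intervention(z, k=1, value=2.0)
```

In this sketch, intervening on z_1 leaves its ancestor z_0 untouched and propagates to its child z_2, while intervening with the factual value reproduces z exactly.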

Paper Structure (26 sections, 37 equations, 7 figures, 4 tables, 2 algorithms)

Figures (7)

  • Figure 1: CausalDiffAE framework. The left side details the training process of CausalDiffAE: the input is encoded into the causal representation $\mathbf{z}_{\text{causal}}$, and a conditional DDIM decoder conditioned on $\mathbf{z}_{\text{causal}}$ and $\mathbf{x}_T$ reconstructs the image. The right side shows the DDIM-based counterfactual generation procedure using a trained CausalDiffAE model.
  • Figure 2: Counterfactual trajectories generated by CausalDiffAE and baseline models for (a) MorphoMNIST and (b) Pendulum datasets. We observe that CausalDiffAE generates much more accurate counterfactuals upon interventions on causal factors compared to baselines.
  • Figure 3: CausalCircuit results (original image: $y_1 = 0.02, y_2 = 0.03, y_3 = 0.04, y_4 = 0.14$)
  • Figure 4: MorphoMNIST Weak Supervision
  • Figure 5: CausalDiffAE generated counterfactuals (MorphoMNIST)
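The counterfactual procedure in Figure 1 can be sketched with deterministic DDIM updates (eta = 0): the factual image is inverted to x_T under its factual z_causal, and x_T is then decoded under an intervened z_causal. The NumPy sketch below is illustrative only; the linear beta schedule, the timestep grid, and `toy_eps_model` (a hypothetical stand-in for a trained conditional noise predictor) are assumptions, not details taken from the paper.

```python
import numpy as np

def alpha_bar_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule; returns alpha_bar[t] for t = 0..T, with alpha_bar[0] = 1 (clean image)."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.concatenate([[1.0], np.cumprod(1.0 - betas)])

def ddim_step(x, t, s, z, eps_model, ab):
    """Deterministic DDIM update (eta = 0) from timestep t to timestep s, conditioned on z."""
    e = eps_model(x, t, z)
    x0_pred = (x - np.sqrt(1.0 - ab[t]) * e) / np.sqrt(ab[t])
    return np.sqrt(ab[s]) * x0_pred + np.sqrt(1.0 - ab[s]) * e

def ddim_encode(x0, z, eps_model, ab, steps):
    """DDIM inversion: run the deterministic update forward in time to obtain x_T."""
    x = x0
    for t, s in zip(steps[:-1], steps[1:]):
        x = ddim_step(x, t, s, z, eps_model, ab)
    return x

def ddim_decode(x_T, z, eps_model, ab, steps):
    """Deterministic reverse sampling from x_T back to t = 0, conditioned on z."""
    rev = steps[::-1]
    x = x_T
    for t, s in zip(rev[:-1], rev[1:]):
        x = ddim_step(x, t, s, z, eps_model, ab)
    return x

def counterfactual(x0, z_factual, z_cf, eps_model, ab, steps):
    """Encode x0 with the factual z to x_T, then decode x_T with the intervened z_cf."""
    x_T = ddim_encode(x0, z_factual, eps_model, ab, steps)
    return ddim_decode(x_T, z_cf, eps_model, ab, steps)

def toy_eps_model(x, t, z):
    """Hypothetical stand-in for a trained eps_theta(x_t, t, z); it ignores x and t."""
    return np.full_like(x, np.tanh(z).sum())

ab = alpha_bar_schedule(T=100)
steps = list(range(0, 101, 10))   # DDIM timesteps 0, 10, ..., 100
x0 = np.linspace(-1.0, 1.0, 16)   # toy "image"
z = np.array([0.5, -0.2, 0.1])
recon = counterfactual(x0, z, z, toy_eps_model, ab, steps)
cf = counterfactual(x0, z, np.array([0.5, 1.5, 0.1]), toy_eps_model, ab, steps)
```

Because the toy noise predictor ignores x and t, inverting and then decoding with the same z recovers x0 exactly here, and the "counterfactual" is just a uniform shift; the sketch only shows the data flow. With a real network, DDIM inversion is approximate.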