Causal Representation Learning with Optimal Compression under Complex Treatments

Wanting Liang, Haoang Chi, Zhiheng Zhang

Abstract

Estimating Individual Treatment Effects (ITE) in multi-treatment scenarios faces two critical challenges: the Hyperparameter Selection Dilemma for balancing weights and the Curse of Dimensionality in computational scalability. This paper derives a novel multi-treatment generalization bound and proposes a theoretical estimator for the optimal balancing weight $\alpha$, eliminating expensive heuristic tuning. We investigate three balancing strategies: Pairwise, One-vs-All (OVA), and Treatment Aggregation. While OVA achieves superior precision in low-dimensional settings, our proposed Treatment Aggregation ensures both accuracy and $O(1)$ scalability as the treatment space expands. Furthermore, we extend our framework to a generative architecture, Multi-Treatment CausalEGM, which preserves the Wasserstein geodesic structure of the treatment manifold. Experiments on semi-synthetic and image datasets demonstrate that our approach significantly outperforms traditional models in estimation accuracy and efficiency, particularly in large-scale intervention scenarios.
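The scaling contrast among the three strategies can be illustrated by counting how many balancing terms each one would need for $K$ treatments. This is a minimal sketch under stated assumptions: Pairwise uses one term per unordered treatment pair ($\binom{K}{2}$, the count given in the Figure 5 caption), OVA is assumed to use one term per treatment, and Treatment Aggregation is assumed to use a single term regardless of $K$. The function name `num_balancing_terms` and the strategy labels are hypothetical and are not taken from the paper.

```python
from math import comb


def num_balancing_terms(strategy: str, k: int) -> int:
    """Count balancing terms for k treatments (illustrative assumption only)."""
    if k < 2:
        raise ValueError("need at least two treatments")
    if strategy == "pairwise":
        # One term per unordered pair of treatments: C(k, 2).
        return comb(k, 2)
    if strategy == "ova":
        # Assumed: each treatment is compared against all the others once.
        return k
    if strategy == "aggregation":
        # Assumed: a single aggregated term, independent of k.
        return 1
    raise ValueError(f"unknown strategy: {strategy}")


for k in (4, 20):
    counts = {s: num_balancing_terms(s, k) for s in ("pairwise", "ova", "aggregation")}
    print(k, counts)
```

Under these assumptions the Pairwise count grows quadratically in $K$ while the Aggregation count stays constant, which is the sense in which the abstract describes $O(1)$ scalability.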

Paper Structure (52 sections, 7 theorems, 53 equations, 8 figures, 1 algorithm)

Key Result

Lemma 2.2

Let $f(\Phi,h):=\widehat{\epsilon}_F(\Phi,h)$ and $g(\Phi):=\widehat{\mathcal{R}}_{\mathcal{S}}(\Phi)$. Consider the constrained problem \eqref{eq:constrained} with budget $\rho$. Assume (i) the feasible set $\{(\Phi,h): g(\Phi)\le\rho\}$ is nonempty and admits a strictly feasible point (Slater condition),
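
For orientation, a penalty–constraint equivalence of this kind typically relates two problems of the following shape. This is only a sketch under assumptions: the exact statement of the constrained problem is not reproduced in this excerpt, and the multiplier symbol $\lambda$ is a hypothetical placeholder rather than the paper's notation.

```latex
\begin{align}
\text{(constrained)}\quad & \min_{\Phi,h}\; f(\Phi,h) \quad \text{s.t.}\quad g(\Phi)\le\rho, \\
\text{(penalized)}\quad   & \min_{\Phi,h}\; f(\Phi,h) + \lambda\, g(\Phi), \qquad \lambda \ge 0.
\end{align}
```

In standard convex settings, a Slater-type condition such as (i) is what lets each budget $\rho$ be matched with some multiplier $\lambda \ge 0$, so that solving one form recovers a solution of the other.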

Figures (8)

  • Figure 1: Performance Comparison (PEHE). (a) All methods outperform baseline at $K=4$. (b) At $K=20$, Pairwise degrades while Aggregation remains robust. (See Appendix \ref{app:efficiency}, Figure \ref{fig:efficiency_appendix} for training efficiency.)
  • Figure 2: Performance and Efficiency Analysis on Digits Dataset. (a) The estimated ADRF (blue) closely tracks the ground truth (red). (b) Dual-axis comparison showing PEHE error (bars) and Training Time (line).
  • Figure 3: Geometric Validation on Hierarchical Treatments. (a) The learned embeddings spontaneously recover the underlying tree structure, placing the Root centrally and separating the L/R branches. (b) Counterfactual interpolation from Leaf LL ($Y=-3$) to Leaf RR ($Y=+3$) respects the causal topology by passing through the Root's effect region ($Y \approx 0$), whereas the linear baseline (grey) ignores the structure.
  • Figure 4: Detailed Architecture of Multi-Treatment CausalEGM. The left panel shows the original binary design, while the right panel highlights our proposed extensions (Embedding Layer and Softmax Activation) for complex treatment regimes.
  • Figure 5: Training Time Efficiency at $K=20$. The Pairwise strategy incurs distinct computational costs due to $\binom{K}{2}$ constraints.
  • ...and 3 more figures

Theorems & Definitions (9)

  • Remark 2.1
  • Lemma 2.2: Penalty--constraint equivalence
  • Remark 2.3: Geometric Extension: Wasserstein Manifolds and Geodesic Causal Inference
  • Lemma 3.2: Multi-treatment generalization bound (schematic form)
  • Lemma 3.3: Profile score for $\alpha$
  • Theorem 3.5: Finite-sample deviation bound for $\widehat{\alpha}_{\mathcal{S}}$
  • Corollary 3.6: Oracle bound guarantee
  • Theorem 3.8: Asymptotic normality of $\widehat{\alpha}_{\mathcal{S}}$
  • Corollary 3.9: Stability scaling with $K$