Neural Causal Abstractions
Kevin Xia, Elias Bareinboim
TL;DR
This work develops a framework for neural causal abstractions that compresses low-level data into high-level causal concepts while preserving interventional and counterfactual inferences across the layers of Pearl's causal hierarchy. It defines constructive abstraction functions $\tau$ from intervariable clusters (groups of low-level variables) and intravariable clusters (groups of their values), enforces layer-specific consistency via the Abstract Invariance Condition, and connects abstraction with classical identification through cluster diagrams (C-DAGs) and with neural identification through Neural Causal Models (NCMs). Using representational NCMs (RNCMs), the approach learns task-aligned representations and enables $\tau$-identifiability of queries, with algorithms both for constructing abstractions and for solving abstract identification tasks. Experiments on nutrition data and colored MNIST show practical gains in identifying, estimating, and sampling causally valid distributions at coarser granularity and in high-dimensional settings. The framework thus offers a scalable, principled path to applying causal reasoning in real-world, high-dimensional domains.
Abstract
The abilities of humans to understand the world in terms of cause-and-effect relationships, as well as to compress information into abstract concepts, are two hallmark features of human intelligence. These two topics have been studied in tandem in the literature under the rubric of causal abstraction theory. In practice, it remains an open problem how best to leverage abstraction theory in real-world causal inference tasks, where the true mechanisms are unknown and only limited data is available. In this paper, we develop a new family of causal abstractions by clustering variables and their domains. This approach refines and generalizes previous notions of abstractions to better accommodate the individual causal distributions spawned by Pearl's causal hierarchy. We show that such abstractions are learnable in practical settings through Neural Causal Models (Xia et al., 2021), enabling the use of the deep learning toolkit to solve various challenging causal inference tasks (identification, estimation, and sampling) at different levels of granularity. Finally, we integrate these results with representation learning to create more flexible abstractions, moving these results closer to practical applications. Our experiments support the theory and illustrate how to scale causal inferences to high-dimensional settings involving image data.
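To make the idea of "clustering variables and their domains" concrete, the following is a minimal sketch of an abstraction function $\tau$ built from two ingredients: a grouping of low-level variables into high-level ones, and a grouping of each group's values into high-level values. It is an illustration under stated assumptions, not the authors' implementation: the variable names (`carbs_g`, `fat_g`, `bmi`), the high-level variables (`Diet`, `Outcome`), and the thresholds are all hypothetical and chosen only to echo the nutrition setting.

```python
from typing import Callable, Dict, Mapping, Tuple

# Grouping of low-level variables (hypothetical names): each high-level
# variable is defined by the set of low-level variables it absorbs.
INTER_CLUSTERS: Dict[str, Tuple[str, ...]] = {
    "Diet": ("carbs_g", "fat_g"),
    "Outcome": ("bmi",),
}


def diet_level(values: Tuple[float, ...]) -> str:
    """Group joint (carbs, fat) values by approximate energy content."""
    carbs, fat = values
    kcal = 4 * carbs + 9 * fat  # rough energy per gram; illustrative only
    return "high" if kcal >= 2000 else "low"  # hypothetical threshold


def outcome_level(values: Tuple[float, ...]) -> str:
    """Group BMI values into two coarse categories."""
    (bmi,) = values
    return "elevated" if bmi >= 25.0 else "normal"  # hypothetical threshold


# Grouping of values: one map per high-level variable, sending each joint
# low-level value of its cluster to a single high-level value.
INTRA_CLUSTERS: Dict[str, Callable[[Tuple[float, ...]], str]] = {
    "Diet": diet_level,
    "Outcome": outcome_level,
}


def tau(v_low: Mapping[str, float]) -> Dict[str, str]:
    """Map a full low-level assignment to its high-level counterpart."""
    covered = [name for members in INTER_CLUSTERS.values() for name in members]
    if sorted(covered) != sorted(v_low):
        raise ValueError("low-level assignment must match the clustered variables")
    return {
        high: INTRA_CLUSTERS[high](tuple(v_low[name] for name in members))
        for high, members in INTER_CLUSTERS.items()
    }


if __name__ == "__main__":
    print(tau({"carbs_g": 300, "fat_g": 100, "bmi": 27.0}))
```

Distinct low-level assignments can map to the same high-level assignment, which is the sense in which the abstraction compresses information. Whether causal queries remain valid after such a compression is a separate question that this sketch does not check.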
