Causal Abstractions of Neural Networks

Atticus Geiger, Hanson Lu, Thomas Icard, Christopher Potts

TL;DR

This work develops a formal causal-abstraction framework to explain neural network behavior by aligning low-level representations with high-level causal model variables and validating causality through interchange interventions, formalized within a constructive abstraction setup. Using MQNLI, which is grounded in a tree-structured natural-logic model, the authors search for alignments (e.g., mapping internal BERT representations to high-level nodes like $S_1=X+Y$ and $S_2=S_1+W$) and experimentally verify causal equivalence of network computations with the high-level model. The case study shows that a BERT-based model partially realizes the natural-logic causal structure, while an LSTM baseline does not, providing evidence that BERT encodes the compositional structure necessary for MQNLI. Overall, the paper presents a principled, testable alternative to probes and gradient-based attributions for explaining neural computation and demonstrates a scalable approach to evaluating abstract causal hypotheses in NLP.

Abstract

Structural analysis methods (e.g., probing and feature attribution) are increasingly important tools for neural network analysis. We propose a new structural analysis method grounded in a formal theory of causal abstraction that provides rich characterizations of model-internal representations and their roles in input/output behavior. In this method, neural representations are aligned with variables in interpretable causal models, and then interchange interventions are used to experimentally verify that the neural representations have the causal properties of their aligned variables. We apply this method in a case study to analyze neural models trained on the Multiply Quantified Natural Language Inference (MQNLI) corpus, a highly complex NLI dataset that was constructed with a tree-structured natural logic causal model. We discover that a BERT-based model with state-of-the-art performance successfully realizes parts of the natural logic model's causal structure, whereas a simpler baseline model fails to show any such structure, demonstrating that BERT representations encode the compositional structure of MQNLI.
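
The alignment-plus-interchange-intervention check described above can be sketched on the arithmetic example from Figure 1, with base input (1, 2, 3) and source input (4, 5, 6). This is a minimal Python sketch, not the authors' code. It assumes that $C_+$ computes an intermediate variable $S = X + Y$ and outputs $S + Z$, and it stands in for $N_+$ with a hypothetical hand-weighted one-hidden-layer network whose hidden unit 0 is aligned with $S$. The names `causal_model`, `network`, `interchange_network`, `interchange_model`, and `abstraction_holds` are illustrative only.

```python
from itertools import product

# High-level causal model C_+ (assumed form): inputs X, Y, Z,
# intermediate variable S = X + Y, output O = S + Z.
def causal_model(inputs, intervene_s=None):
    x, y, z = inputs
    s = x + y if intervene_s is None else intervene_s
    return {"S": s, "O": s + z}

# Hypothetical stand-in for N_+: one hidden layer with hand-set weights,
# so hidden unit 0 computes x + y and hidden unit 1 copies z.
W1 = [[1.0, 1.0, 0.0],
      [0.0, 0.0, 1.0]]
W2 = [1.0, 1.0]

def network(inputs, patch=None):
    """Forward pass; `patch` maps hidden-unit index -> value to overwrite."""
    hidden = [sum(w * v for w, v in zip(row, inputs)) for row in W1]
    for i, value in (patch or {}).items():
        hidden[i] = value
    return {"hidden": hidden, "O": sum(w * h for w, h in zip(W2, hidden))}

def interchange_network(base, source, unit=0):
    # Run on the source, then copy the aligned hidden unit into the base run.
    source_value = network(source)["hidden"][unit]
    return network(base, patch={unit: source_value})["O"]

def interchange_model(base, source):
    # Set S in the base run to the value S takes on the source input.
    return causal_model(base, intervene_s=causal_model(source)["S"])["O"]

def abstraction_holds(inputs, unit=0):
    # The alignment passes only if outputs match for every (base, source) pair.
    return all(interchange_network(b, s, unit) == interchange_model(b, s)
               for b, s in product(inputs, repeat=2))

print(interchange_network((1, 2, 3), (4, 5, 6)))  # 12.0
print(interchange_model((1, 2, 3), (4, 5, 6)))    # 12
```

Passing `unit=1` instead aligns $S$ with the hidden unit that copies $Z$; the check then fails (for the pair above, the intervened network outputs 9.0 while the intervened model outputs 12), which is how an interchange intervention can reject a proposed alignment.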

Paper Structure

This paper contains 51 sections, 25 equations, 9 figures, 2 tables.

Figures (9)

  • Figure 1: Our motivating example, in which we hypothesize that a symbolic computation $C_+$ is a causal abstraction of a neural network $N_+$ under a particular alignment (top). We can experimentally confirm this hypothesis by conducting an interchange intervention on both the network and the computation for every pair of inputs and checking whether the intervened network and the intervened computation have the same counterfactual output behavior. We schematically depict an interchange intervention on the network $N_+$ (bottom left) and the computation $C_+$ (bottom right) with the base input $(1,2,3)$ and the source input $(4,5,6)$. Observe that the output of the intervened neural network matches the output of the intervened symbolic computation, so the intervention succeeds for this pair of inputs.
  • Figure 2: The natural logic causal model (top), MQNLI examples (left) and MQNLI results (right).
  • Figure 3: A BERT-based NLI model (left) aligned with the natural logic causal model $C_{\emph{NatLog}}^{\text{NP}_\text{Obj}}$ (right), where the fourth vector representation above the $\text{Adj}_\text{Obj}^{P}$ token in the network is aligned with $\text{NP}_\text{Obj}$, the variable representing the relation between the object noun phrases. In a sample of 1000 examples, we found a subset of 383 examples on which $C_{\emph{NatLog}}^{\text{NP}_\text{Obj}}$ is an abstraction of $N_{\emph{NLI}}$ under this alignment.
  • Figure 4: Interchange intervention and probing results for the $\text{NP}_\text{Obj}$ position. Vertical axes denote layers of BERT and horizontal axes denote the token position of hidden representations. The intervention success rates reported here are calculated based on intervention experiments with a change in the output label. Clique sizes are reported as % of 1000 examples.
  • Figure 5: Full probing and interchange intervention results for high-level nodes $\text{NP}_\text{Obj}$, $\text{N}_\text{Obj}$, $\text{Adj}_\text{Obj}$, VP, V, and Adv. Vertical axes denote BERT layers and horizontal axes denote the token position of hidden representations. Intervention success rates are based on experiments with a change in the output label. Clique sizes are reported as a percentage of all examples.
  • ...and 4 more figures
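
The captions for Figures 4 and 5 report intervention success rates computed only over interventions that change the output label, together with clique sizes as a percentage of the sampled examples. The Python sketch below shows one way such statistics could be computed. It assumes that an edge joins two examples when the interchange intervention succeeds with each one as base and the other as source, and that "a change in the output label" means the high-level model's intervened label differs from its un-intervened label. Both are interpretations, not details confirmed here. The callables `net_ii`, `model_ii`, `model_out` and the helpers `intervention_stats` and `greedy_clique` are hypothetical.

```python
from itertools import permutations

def intervention_stats(examples, net_ii, model_ii, model_out):
    """Return (success rate, edge set).

    net_ii(base, source) / model_ii(base, source): output label after the
    interchange intervention on the network / high-level model (hypothetical
    callables). model_out(base): un-intervened high-level label.
    """
    agree = {}
    tested = successes = 0
    for b, s in permutations(range(len(examples)), 2):
        base, source = examples[b], examples[s]
        expected = model_ii(base, source)
        agree[(b, s)] = net_ii(base, source) == expected
        # Only count interventions that change the high-level output label.
        if expected != model_out(base):
            tested += 1
            successes += agree[(b, s)]
    rate = successes / tested if tested else float("nan")
    # Edge between two examples if the intervention succeeds in both directions.
    edges = {frozenset(pair) for pair, ok in agree.items()
             if ok and agree[pair[::-1]]}
    return rate, edges

def greedy_clique(n, edges):
    """Greedy clique: a heuristic lower bound, not a maximum clique."""
    degree = [0] * n
    for edge in edges:
        for v in edge:
            degree[v] += 1
    clique = []
    for v in sorted(range(n), key=lambda v: -degree[v]):
        if all(frozenset((v, u)) in edges for u in clique):
            clique.append(v)
    return clique

# Toy usage on an arithmetic model: S = X + Y, O = S + Z.
examples = [(1, 2, 3), (4, 5, 6), (0, 0, 1), (2, 2, 2), (3, 0, 7)]
model_out = lambda x: x[0] + x[1] + x[2]
model_ii = lambda base, src: src[0] + src[1] + base[2]
# Hypothetical flawed network: wrong whenever the source's X exceeds 3.
net_ii = lambda base, src: model_ii(base, src) if src[0] <= 3 else 0

rate, edges = intervention_stats(examples, net_ii, model_ii, model_out)
clique = greedy_clique(len(examples), edges)
print(round(rate, 3), clique)  # 0.778 [0, 2, 3, 4]
```

In the toy run, the pairs (0, 4) and (4, 0) leave the label unchanged and are excluded from the rate. The clique covers 4 of 5 examples (80%). Exact maximum-clique search is NP-hard in general, which is why this sketch settles for a greedy heuristic; the search procedure used in the paper is not described on this page.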

Theorems & Definitions (7)

  • Definition F.1
  • Definition F.2
  • Definition F.3
  • Definition F.4
  • Definition F.5
  • Definition F.6
  • Definition F.7: Constructive $\tau$-abstraction