Efficient Discovery of Approximate Causal Abstractions via Neural Mechanism Sparsification

Amir Asiaee

TL;DR

Treating a trained network as a deterministic SCM, this work derives an Interventional Risk objective whose second-order expansion yields closed-form criteria for replacing units with constants or folding them into neighbors. The resulting sparse abstractions are validated via interchange interventions.

Abstract

Neural networks are hypothesized to implement interpretable causal mechanisms, yet verifying this requires finding a causal abstraction -- a simpler, high-level Structural Causal Model (SCM) faithful to the network under interventions. Discovering such abstractions is hard: it typically demands brute-force interchange interventions or retraining. We reframe the problem by viewing structured pruning as a search over approximate abstractions. Treating a trained network as a deterministic SCM, we derive an Interventional Risk objective whose second-order expansion yields closed-form criteria for replacing units with constants or folding them into neighbors. Under uniform curvature, our score reduces to activation variance, recovering variance-based pruning as a special case while clarifying when it fails. The resulting procedure efficiently extracts sparse, intervention-faithful abstractions from pretrained networks, which we validate via interchange interventions.

Paper Structure (57 sections, 8 theorems, 22 equations, 6 figures, 3 tables)

Key Result

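Proposition 1

Under $\mathrm{do}(a^{(\ell)}_j := c)$, let $W' := W^{(\ell+1)}_{:,\backslash j}$ (the matrix $W^{(\ell+1)}$ with column $j$ removed) and $b' := b^{(\ell+1)} + c\,W^{(\ell+1)}_{:,j}$. Then $\bigl(W^{(\ell+1)} a^{(\ell)} + b^{(\ell+1)}\bigr)\big|_{a^{(\ell)}_j=c} = W' a^{(\ell)}_{\backslash j} + b'$, where $a^{(\ell)}_{\backslash j}$ is $a^{(\ell)}$ with coordinate $j$ removed.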
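
The identity in Proposition 1 can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation; the helper name `fold_constant`, the layer sizes, and the random values are illustrative assumptions.

```python
import numpy as np

def fold_constant(W_next, b_next, j, c):
    """Compile do(a_j := c) into the next layer (hypothetical helper).

    Absorbs c * W_next[:, j] into the bias, then deletes column j.
    """
    b_folded = b_next + c * W_next[:, j]
    W_folded = np.delete(W_next, j, axis=1)
    return W_folded, b_folded

rng = np.random.default_rng(0)
W_next = rng.normal(size=(4, 6))   # stands in for W^{(l+1)}: 4 outputs, 6 inputs
b_next = rng.normal(size=4)        # stands in for b^{(l+1)}
a = rng.normal(size=6)             # stands in for a^{(l)}
j, c = 2, 0.7                      # unit to replace and its constant (arbitrary)

# Left-hand side: intervene on coordinate j, then apply the original layer.
a_do = a.copy()
a_do[j] = c
lhs = W_next @ a_do + b_next

# Right-hand side: apply the compiled, smaller layer to the remaining units.
W_f, b_f = fold_constant(W_next, b_next, j, c)
rhs = W_f @ np.delete(a, j) + b_f

max_abs_diff = float(np.max(np.abs(lhs - rhs)))
```

Here `max_abs_diff` should sit at floating-point rounding level, since the two sides differ only in the order of summation.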

Figures (6)

  • Figure 1: Overview. (a) Causal abstraction as commutativity: a high-level SCM $M_H$ abstracts a low-level model $M_L$ if intervening at the high level (via $I$) and intervening at the low level (via $\omega(I)$) yield consistent results under the state map $\tau$. (b) Our discovery and verification pipeline: given a dense network, we compute per-unit scores $s_j$ via a second-order surrogate, select low-score units for mechanism replacement, compile the result into a smaller network $M_H$, and verify faithfulness via interchange interventions (IIA).
  • Figure 2: Compilation of mechanism replacements. (a) Constant replacement: setting $a_j^{(\ell)} := c$ severs incoming edges; the effect on downstream units is absorbed into the bias via $b^{(\ell+1)} \leftarrow b^{(\ell+1)} + c \cdot W_{:,j}^{(\ell+1)}$, then column $j$ is deleted. (b) Affine replacement: replacing $a_j$ with a linear combination of retained units $\{a_k\}_{k \in P}$ redistributes outgoing weights to those units before deletion.
  • Figure 3: Constructive abstraction discovery
  • Figure 4: MNIST fidelity and accuracy vs. keep size. Fidelity is measured via interchange interventions on retained penultimate coordinates.
  • Figure 5: Scaling invariance stress test (keep $= 256$). Absolute metrics under exact function-preserving scaling reparameterizations (mean with 95% CI over 10 seeds). Logit-MSE (and cwvar) are stable (Jaccard $= 1$) and maintain strong-swap fidelity, while VBP is unstable and substantially less faithful under interventions.
  • ...and 1 more figures
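
The affine replacement described in the Figure 2(b) caption can be sketched in the same way. This listing does not state Proposition 2, so the folding rule below is an assumption obtained by substituting $a_j := \sum_{k \in P} \alpha_k a_k + \beta$ into the next layer's pre-activation. The names `fold_affine`, `parents`, `alpha`, and `beta` are hypothetical.

```python
import numpy as np

def fold_affine(W_next, b_next, j, parents, alpha, beta):
    """Compile a_j := sum_k alpha[k] * a[parents[k]] + beta into the next layer.

    Assumed rule: each retained parent column k gains alpha_k * W_next[:, j],
    the bias gains beta * W_next[:, j], and column j is then deleted.
    """
    if j in parents:
        raise ValueError("unit j cannot be one of its own parents")
    W = W_next.copy()
    W[:, parents] += np.outer(W_next[:, j], alpha)  # redistribute outgoing weights
    b = b_next + beta * W_next[:, j]                # absorb the offset into the bias
    return np.delete(W, j, axis=1), b

rng = np.random.default_rng(1)
W_next = rng.normal(size=(3, 5))
b_next = rng.normal(size=3)
a = rng.normal(size=5)
j, parents = 2, [0, 4]                 # replaced unit and retained units P
alpha, beta = np.array([0.5, -1.2]), 0.3

# Intervene: overwrite a_j with the affine function of the retained units.
a_do = a.copy()
a_do[j] = alpha @ a[parents] + beta
lhs = W_next @ a_do + b_next

# Compiled layer acting on the activations with coordinate j removed.
W_f, b_f = fold_affine(W_next, b_next, j, parents, alpha, beta)
rhs = W_f @ np.delete(a, j) + b_f
```

Setting `alpha` to zeros reduces this sketch to the constant replacement of Figure 2(a), with `beta` playing the role of $c$.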

Theorems & Definitions (8)

  • proposition 1: Bias folding for constant replacement
  • proposition 2: Weight folding for affine replacement
  • lemma 1: Exact transformation
  • proposition 3: Quadratic proxy
  • proposition 4: Optimal constant
  • proposition 5: Variance ranking as a special case
  • proposition 6: Affine fit
  • proposition 7: Additive scores
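
Propositions 4 and 5 are only named in this listing, so the following is a hedged illustration rather than the paper's statement. It assumes that, under uniform curvature, the cost of replacing unit $j$ with a constant $c$ is proportional to $\mathbb{E}[(a_j - c)^2]$. Under that assumption the best constant is the mean activation and the remaining cost is the activation variance, which gives a variance ranking. The activations and spreads below are synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic activations: 5000 samples of 4 units with different spreads (assumed).
scales = np.array([0.1, 2.0, 0.5, 1.0])
A = rng.normal(loc=1.0, size=(5000, 4)) * scales

def constant_cost(A, c):
    """Mean squared error per unit when each unit is replaced by the constant c."""
    return np.mean((A - c) ** 2, axis=0)

c_star = A.mean(axis=0)             # minimizer of the squared-error cost
scores = constant_cost(A, c_star)   # equals the (population) activation variance
ranking = np.argsort(scores)        # lowest-variance units are pruned first

# Any other constant is strictly worse: shifting c_star by d adds d**2 to the cost.
shifted = constant_cost(A, c_star + 0.1)
```

If the curvature differs across units, each variance would be multiplied by a unit-specific weight and the ordering can change, which is consistent with the abstract's remark that the analysis clarifies when variance-based pruning fails.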