MINAR: Mechanistic Interpretability for Neural Algorithmic Reasoning

Jesse He, Helen Jenne, Max Vargas, Davis Brown, Gal Mishne, Yusu Wang, Henry Kvinge

TL;DR

This work introduces Mechanistic Interpretability for Neural Algorithmic Reasoning (MINAR), an efficient circuit discovery toolbox that adapts attribution patching methods from mechanistic interpretability to the GNN setting. Two case studies show that MINAR recovers faithful neuron-level circuits from GNNs trained on algorithmic tasks.

Abstract

The recent field of neural algorithmic reasoning (NAR) studies the ability of graph neural networks (GNNs) to emulate classical algorithms like Bellman-Ford, a phenomenon known as algorithmic alignment. At the same time, recent advances in large language models (LLMs) have spawned the study of mechanistic interpretability, which aims to identify granular model components like circuits that perform specific computations. In this work, we introduce Mechanistic Interpretability for Neural Algorithmic Reasoning (MINAR), an efficient circuit discovery toolbox that adapts attribution patching methods from mechanistic interpretability to the GNN setting. We show through two case studies that MINAR recovers faithful neuron-level circuits from GNNs trained on algorithmic tasks. Our study sheds new light on the process of circuit formation and pruning during training, and gives new insight into how GNNs trained to perform multiple tasks in parallel reuse circuit components for related tasks. Our code is available at https://github.com/pnnl/MINAR.
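
As a rough illustration of the algorithmic alignment mentioned above, the sketch below writes one round of Bellman-Ford as a message-passing step with min aggregation: each edge sends the message `dist[u] + w`, and each node keeps the minimum of its current value and its incoming messages. This is plain Python written for this summary, not code from the MINAR repository, and it does not reproduce the paper's MinAggGNN architecture; the function names and the toy graph are assumptions chosen for the example.

```python
import math


def bellman_ford_step(dist, edges):
    """One synchronous relaxation round, phrased as min-aggregation message passing.

    dist  -- current distance estimate per node (list of floats)
    edges -- directed edges as (source, target, weight) triples
    """
    new_dist = list(dist)
    for u, v, w in edges:
        # Message along edge u -> v is dist[u] + w; the aggregator is min.
        new_dist[v] = min(new_dist[v], dist[u] + w)
    return new_dist


def bellman_ford(num_nodes, edges, source):
    """Run num_nodes - 1 relaxation rounds from the given source node."""
    dist = [math.inf] * num_nodes
    dist[source] = 0.0
    for _ in range(num_nodes - 1):
        dist = bellman_ford_step(dist, edges)
    return dist


# Hypothetical toy graph: 4 nodes, weighted directed edges.
edges = [(0, 1, 4.0), (0, 2, 1.0), (2, 1, 2.0), (1, 3, 5.0)]
distances = bellman_ford(4, edges, source=0)
print(distances)
```

A GNN layer with a min aggregator has the same shape as `bellman_ford_step`, which is one informal way to read the claim that such networks can emulate the algorithm.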

Paper Structure (29 sections, 21 equations, 26 figures, 1 table, 1 algorithm)

Figures (26)

  • Figure 1: MSE training loss, multiplicative test loss, and $L_1$ regularization term for Bellman-Ford MinAggGNN.
  • Figure 2: Identified circuit in the Bellman-Ford MinAggGNN (left) and parallel Bellman-Ford and BFS MinAggGNN (right). Nodes are individual MinAggGNN neurons. Input and output neurons are colored white, $f_{\mathrm{Agg}}$ neurons are colored blue, and $f_{\mathrm{Up}}$ neurons are colored orange. Circuit edges are colored by corresponding model weights.
  • Figure 3: Multiplicative test loss, MSE test loss, and $L_1$ regularization terms for the Bellman-Ford network and circuit.
  • Figure 4: MSE training loss, multiplicative test loss, reachability accuracy, and $L_1$ regularization term for parallel Bellman-Ford and BFS MinAggGNN.
  • Figure 5: Characterization score of circuits computed using Weight, WeightGrad, EAP, and EAP-IG ($m=20$) for SALSA-CLRS GNN for $K \in \{10, 25, 50, 100, 250, 500, 1000, 1500, 2000, 2500, 3000\}$.
  • ...and 21 more figures

Theorems & Definitions (1)

  • Definition 3.1: Model Computation Graph