
Understanding Task Representations in Neural Networks via Bayesian Ablation

Andrew Nam, Declan Campbell, Thomas Griffiths, Jonathan Cohen, Sarah-Jane Leslie

TL;DR

This work tackles the challenge of interpreting latent task representations in neural networks by introducing a Bayesian ablation framework that places a posterior distribution $P(m \mid t,c)$ over ablation masks $m$, conditioned on the task $t$ and on correctness $c$. Applying this framework to the ISC model, the authors quantify properties of task representations, such as distributedness, manifold complexity, and differentiation, using entropy-based and information-theoretic metrics. They assess polysemanticity with the normalized reverse-inference measure $I_n(t,m \mid c) = 1 - H(t \mid m,c)/H(t \mid c)$. The approach uses an odds-ratio transformation to stabilize posterior inferences and compares task representations with metrics such as the Wasserstein distance and the KL divergence. The analysis reveals largely modular representations and, when the probabilistic AMD is used, strong alignment with conventional similarity measures. The results provide a principled, causal lens for interpreting task representations, with implications for cognitive modeling and for interpretability across neural architectures, while highlighting challenges in scaling and generalization to larger, more complex models.
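The normalized reverse-inference measure above can be computed directly from a joint table $P(t, m \mid c)$ using the identity $H(t \mid m,c) = H(t,m \mid c) - H(m \mid c)$. The following is a minimal sketch, not the authors' pipeline: it assumes a small, explicitly enumerated mask space (masks as table columns) and entropies in bits. The function names are hypothetical, and the KL divergence helper illustrates just one of the comparison metrics mentioned.

```python
import numpy as np

def entropy(p):
    """Shannon entropy in bits of a probability array (zero entries ignored)."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-(p * np.log2(p)).sum())

def normalized_reverse_inference(joint_tm):
    """I_n(t, m | c) = 1 - H(t | m, c) / H(t | c).

    Hypothetical helper: joint_tm[i, j] is assumed to hold P(t=i, m=j | c)
    for one fixed correctness value c. Requires H(t | c) > 0, i.e. at least
    two tasks with nonzero probability.
    """
    joint = np.asarray(joint_tm, dtype=float)
    joint = joint / joint.sum()            # normalize defensively
    h_t = entropy(joint.sum(axis=1))       # H(t | c): marginal over tasks
    h_m = entropy(joint.sum(axis=0))       # H(m | c): marginal over masks
    h_t_given_m = entropy(joint) - h_m     # H(t | m, c) = H(t, m | c) - H(m | c)
    return 1.0 - h_t_given_m / h_t

def kl_divergence(p, q):
    """KL(p || q) in bits between two mask distributions over the same masks.

    Assumes q > 0 wherever p > 0; otherwise the divergence is infinite.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p, q = p / p.sum(), q / q.sum()
    support = p > 0
    return float((p[support] * np.log2(p[support] / q[support])).sum())

if __name__ == "__main__":
    separated = [[0.5, 0.0], [0.0, 0.5]]     # each task uses its own mask
    shared = [[0.25, 0.25], [0.25, 0.25]]    # both tasks use masks identically
    print(normalized_reverse_inference(separated))  # 1.0
    print(normalized_reverse_inference(shared))     # 0.0
    print(kl_divergence([1.0, 0.0], [0.5, 0.5]))    # 1.0
```

Because $I_n$ is a ratio of entropies, the choice of log base cancels out. A value of 1 means the mask fully identifies the task; a value of 0 means the mask carries no information about which task is being performed.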

Abstract

Neural networks are powerful tools for cognitive modeling due to their flexibility and emergent properties. However, interpreting their learned representations remains challenging due to their sub-symbolic semantics. In this work, we introduce a novel probabilistic framework for interpreting latent task representations in neural networks. Inspired by Bayesian inference, our approach defines a distribution over representational units to infer their causal contributions to task performance. Using ideas from information theory, we propose a suite of tools and metrics to illuminate key model properties, including representational distributedness, manifold complexity, and polysemanticity.
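The idea of a distribution over representational units that reflects their causal contribution can be illustrated with Bayes' rule on a toy problem, $P(m \mid t, c) \propto P(c \mid m, t)\,P(m)$. This is an illustrative sketch under stated assumptions, not the authors' procedure. It assumes a uniform prior over binary masks, with $m_i = 1$ meaning unit $i$ is kept. A hypothetical `toy_is_correct` function stands in for running an ablated network. Exhaustive enumeration is used, which is only feasible for a handful of units.

```python
import itertools
import numpy as np

def mask_posterior(n_units, is_correct, task):
    """Enumerate P(m | t, c=1) over all 2**n_units binary masks via Bayes' rule.

    Assumptions (illustrative only): uniform prior P(m), and is_correct(m, t)
    returns P(c=1 | m, t), either as a bool or as a probability in [0, 1].
    """
    masks = np.array(list(itertools.product([0, 1], repeat=n_units)))
    prior = np.full(len(masks), 1.0 / len(masks))                       # P(m)
    likelihood = np.array([float(is_correct(m, task)) for m in masks])  # P(c=1 | m, t)
    unnormalized = likelihood * prior
    return masks, unnormalized / unnormalized.sum()

def toy_is_correct(mask, task):
    """Hypothetical stand-in for evaluating an ablated network on one task.

    Task 0 succeeds if unit 0 is kept AND unit 1 or unit 2 is kept;
    task 1 succeeds whenever unit 2 is kept.
    """
    if task == 0:
        return bool(mask[0] and (mask[1] or mask[2]))
    return bool(mask[2])

if __name__ == "__main__":
    for task in (0, 1):
        masks, post = mask_posterior(3, toy_is_correct, task)
        # Per-unit marginals P(m_i = 1 | t, c=1)
        print(task, post @ masks)  # task 0 ≈ [1.0, 0.667, 0.667]; task 1 ≈ [0.5, 0.5, 1.0]
```

In this toy example, unit 0 is necessary for task 0 but irrelevant to task 1, and the posterior marginals reflect exactly that. For networks of realistic width, one would presumably sample masks rather than enumerate all $2^n$ of them.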

Paper Structure

This paper contains 20 sections, 17 equations, and 7 figures.

Figures (7)

  • Figure 1: ISC Model. The number of units in each layer is shown in parentheses. A sigmoid activation function is applied after each linear layer.
  • Figure 2: Task representation unit values $h_i(t)$ and their importance $1 - H(m_i \mid t,c)$.
  • Figure 3: Correlation between task representation metrics and task acquisition order at different accuracy thresholds. Task acquisition order is the order in which a model's task accuracy first exceeds the specified threshold.
  • Figure 4: Percentage of normalized mutual information captured by the full AMD, $P(t \mid m)$, and by marginal unit distributions, $P(t \mid m_i)$. Each point represents a different model seed in 'Full' (10 total) or a combination of model seeds and representational units (240 total) in 'Marginal'.
  • Figure 5: Spearman correlation between similarity measures.
  • ...and 2 more figures
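Figure 2 scores each unit by $1 - H(m_i \mid t,c)$. For a given task $t$ and correctness $c$, this score depends only on the posterior marginal $P(m_i = 1 \mid t,c)$. This assumes binary masks and entropy in bits, so that the score lies in $[0,1]$; both are assumptions. The following minimal sketch uses hypothetical function names.

```python
import numpy as np

def binary_entropy(p):
    """Entropy in bits of a Bernoulli(p) mask variable; defined as 0 at p in {0, 1}."""
    p = np.clip(np.atleast_1d(np.asarray(p, dtype=float)), 0.0, 1.0)
    h = np.zeros_like(p)
    interior = (p > 0) & (p < 1)   # avoid log2(0) at the endpoints
    q = p[interior]
    h[interior] = -(q * np.log2(q) + (1 - q) * np.log2(1 - q))
    return h

def unit_importance(p_keep):
    """Importance 1 - H(m_i | t, c) per unit, given P(m_i = 1 | t, c) for each unit."""
    return 1.0 - binary_entropy(p_keep)

if __name__ == "__main__":
    print(unit_importance([1.0, 0.5, 0.0]))  # [1. 0. 1.]
```

As written, the score is high whenever the posterior is certain about a unit, whether the unit is reliably kept ($p \to 1$) or reliably ablated ($p \to 0$). A score near 0 means that the unit's state is unconstrained for that task.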