Exploring higher-order neural network node interactions with total correlation

Thomas Kerby, Teresa White, Kevin Moon

TL;DR

Higher-order interactions (HOIs) are fundamental, but the number of candidate interactions scales as $O(2^n)$ with the number of variables $n$, making global analysis intractable. Local CorEx embeds the data with PHATE, partitions the embedding into clusters, and learns local latent factors by optimizing total correlation within each cluster, with latent factors of the form $z = Wx + \epsilon$. The method is demonstrated on synthetic data, the Communities and Crime dataset, and MNIST, and is extended to neural-network interpretability by identifying local groups of hidden nodes that influence specific logits; a dropout analysis shows how redundancy grows in deeper layers. The approach provides a scalable, interpretable, unsupervised framework for discovering HOIs across heterogeneous data and neural representations.
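
To make the $O(2^n)$ scaling concrete, the following minimal sketch counts how many variable subsets of a given minimum size exist among $n$ variables. The function name `count_interaction_sets` and the `min_order` parameter are illustrative assumptions, not part of the paper.

```python
from math import comb


def count_interaction_sets(n: int, min_order: int = 2) -> int:
    """Count subsets of n variables that contain at least min_order variables.

    Hypothetical helper for illustration: each such subset is one candidate
    interaction that a global, exhaustive analysis would have to consider.
    """
    return sum(comb(n, k) for k in range(min_order, n + 1))


if __name__ == "__main__":
    for n in (5, 10, 20, 40):
        # With min_order=2 the count equals 2**n - n - 1.
        print(n, count_interaction_sets(n))
```

With the default `min_order=2`, the count is $2^n - n - 1$, so it roughly doubles with every added variable, which is why restricting the search to local clusters and a small number of latent factors is attractive.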

Abstract

In domains such as ecological systems, collaborations, and the human brain, variables interact in complex ways. Yet accurately characterizing higher-order variable interactions (HOIs) is a difficult problem, and it becomes harder still when the HOIs change across the data. To address this problem, we propose a new method called Local Correlation Explanation (CorEx), which captures HOIs at a local scale by first clustering data points based on their proximity on the data manifold. We then use a multivariate version of the mutual information, called the total correlation, to construct a latent factor representation of the data within each cluster and thereby learn the local HOIs. We use Local CorEx to explore HOIs in synthetic and real-world data and to extract hidden insights about the data structure. Lastly, we demonstrate Local CorEx's suitability for exploring and interpreting the inner workings of trained neural networks.
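
As a rough illustration of why local analysis can matter, the sketch below computes the standard total correlation, $TC(X) = \sum_i H(X_i) - H(X_1, \dots, X_n)$, with a plug-in entropy estimate on discrete toy data. This is not the Linear CorEx optimization used in the paper; the toy clusters and the helper names `entropy` and `total_correlation` are assumptions made for the example.

```python
from collections import Counter
from math import log2


def entropy(samples):
    """Plug-in Shannon entropy, in bits, of a list of hashable outcomes."""
    counts = Counter(samples)
    total = len(samples)
    return -sum((c / total) * log2(c / total) for c in counts.values())


def total_correlation(rows):
    """TC(X) = sum_i H(X_i) - H(X_1, ..., X_n) for rows of discrete values."""
    columns = list(zip(*rows))
    marginal_sum = sum(entropy(list(col)) for col in columns)
    joint = entropy([tuple(row) for row in rows])
    return marginal_sum - joint


# Hypothetical toy data (not from the paper): two "clusters" whose
# dependency structure differs.
# Cluster A: the third variable is the XOR of the first two, so every pair
# of variables is independent but the triple is strongly dependent.
cluster_a = [(x, y, x ^ y) for x in (0, 1) for y in (0, 1)]
# Cluster B: all three variables are independent.
cluster_b = [(x, y, z) for x in (0, 1) for y in (0, 1) for z in (0, 1)]

tc_a = total_correlation(cluster_a)                # 1 bit of three-way dependence
tc_b = total_correlation(cluster_b)                # no dependence
tc_pooled = total_correlation(cluster_a + cluster_b)
tc_pair = total_correlation([(r[0], r[1]) for r in cluster_a])  # pairwise view

if __name__ == "__main__":
    print(tc_a, tc_b, tc_pooled, tc_pair)
```

In this toy setting the pairwise total correlation in cluster A is zero even though the three-way total correlation is one bit, and pooling both clusters dilutes the signal. That is the kind of effect that motivates estimating interactions within clusters rather than over the whole dataset.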

Paper Structure

This paper contains 18 sections, 2 equations, 22 figures, 4 tables, 2 algorithms.

Figures (22)

  • Figure 1: Overview of the Local CorEx algorithm. (a) PHATE visualization of the MNIST dataset. (b) $k$-means clustering is applied to the PHATE embedding to generate the local clusters. (c-d) A cluster is chosen and passed through Linear CorEx. (e) We visualize the mutual information between the learned CorEx latent factors and the original features to identify HOIs.
  • Figure 2: (a) Two-dimensional PHATE embedding of the communities and crime dataset [misc_communities_and_crime_183], with the $10$ clusters resulting from Local CorEx plotted in different colors. (b) PHATE embedding colored by the percentage of people living in urban areas. (c) PHATE embedding colored by median household income. (d) PHATE embedding colored by the percentage of the population aged 65 and over.
  • Figure 3: (Top left) A two-dimensional PHATE embedding of the MNIST dataset colored by data labels. (Top right) The same 2D PHATE embedding colored by the Local CorEx clusters. The clusters largely respect the boundaries between classes. (Bottom) The average pixel values found in each of the clusters. The outline of the majority digit present in each cluster is easily visible.
  • Figure 4: Visualizing the square-rooted mutual information between the first 15 Local CorEx factors trained on cluster 13 with the original features of the MNIST dataset.
  • Figure 5: Visualizing the effect of perturbing the average neural network hidden state representations of cluster 16 in the MNIST test dataset. (a) Plots for perturbing the H1 representation. The first row corresponds to the first Local CorEx factor and the second row to the second Local CorEx factor. (b) Same as (a) but for H2. In each group of plots, the leftmost image is generated by subtracting the mutual information between the Local CorEx factor and the hidden nodes from the average representation. The second image shows the average hidden state representation. The third image is generated by adding the mutual information between the Local CorEx factor and the hidden nodes to the average representation. Finally, the rightmost column plots the mutual information between the Local CorEx factor and the model logits. This analysis gives visual intuition for the role the grouped hidden nodes play.
  • ...and 17 more figures
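
The perturbation described in the Figure 5 caption can be sketched as simple vector arithmetic: the mutual information between one Local CorEx factor and each hidden node is subtracted from, and added to, the cluster's average hidden representation. The function name `perturb_hidden_state` and the numeric values below are hypothetical, and how the perturbed vectors are turned into the images shown in the figure is not specified in this excerpt.

```python
def perturb_hidden_state(mean_activation, factor_mi):
    """Return (minus, mean, plus) variants of an average hidden state.

    mean_activation: average activation of each hidden node in one cluster.
    factor_mi: mutual information between one latent factor and each node.
    Both arguments are illustrative stand-ins, not values from the paper.
    """
    if len(mean_activation) != len(factor_mi):
        raise ValueError("expected one mutual-information value per hidden node")
    minus = [a - m for a, m in zip(mean_activation, factor_mi)]
    plus = [a + m for a, m in zip(mean_activation, factor_mi)]
    return minus, list(mean_activation), plus


# Hypothetical three-node hidden layer.
mean_h = [0.5, 1.0, 0.25]
mi_h = [0.25, 0.0, 0.125]
minus_h, center_h, plus_h = perturb_hidden_state(mean_h, mi_h)

if __name__ == "__main__":
    print(minus_h, center_h, plus_h)
```

Nodes with zero mutual information are left unchanged by the perturbation, so only the group of hidden nodes associated with the chosen factor moves away from the average.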