Visualizing Neural Network Imagination

Nevan Wichers, Victor Tao, Riccardo Volpato, Fazl Barez

TL;DR

This paper tackles interpretability by visualizing the hidden intermediate states a neural network represents while predicting a final environment state. It introduces an encoder–RNN–decoder architecture applied to Conway's Game of Life (GoL) and augments it with autoencoder regularization and adversarial decoder training so that decoding the hidden representations yields GoL-like intermediate reconstructions. A thresholded pixel-matching metric measures how well the intermediate decodings align with ground-truth GoL states. Experiments show that architectural choices and training objectives influence interpretability, with autoencoder and adversarial training generally improving it. The approach shows promise for revealing network 'imagination' in a controlled setting, but it does not yet scale to more complex domains such as chess.

Abstract

In certain situations, neural networks will represent environment states in their hidden activations. Our goal is to visualize what environment states the networks are representing. We experiment with a recurrent neural network (RNN) architecture with a decoder network at the end. After training, we apply the decoder to the intermediate representations of the network to visualize what they represent. We define a quantitative interpretability metric and use it to demonstrate that hidden states can be highly interpretable on a simple task. We also develop autoencoder and adversarial techniques and show that they benefit interpretability.
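
The inference procedure described in the abstract, reusing the trained decoder on every intermediate hidden state, can be sketched as follows. This is a minimal illustration, not the authors' implementation: `encode`, `rnn_step`, and `decode` are hypothetical callables standing in for the trained encoder, the weight-shared RNN layer, and the decoder, and decoding the encoder output as the first frame is an assumption.

```python
from typing import Any, Callable, List


def decode_intermediates(
    x: Any,
    encode: Callable[[Any], Any],    # hypothetical: trained encoder
    rnn_step: Callable[[Any], Any],  # hypothetical: one weight-shared RNN layer
    decode: Callable[[Any], Any],    # hypothetical: trained decoder
    num_steps: int,
) -> List[Any]:
    """Return the decoded state after the encoder and after each RNN step."""
    hidden = encode(x)
    frames = [decode(hidden)]  # assumption: the encoder output is decoded too
    for _ in range(num_steps):
        hidden = rnn_step(hidden)      # the same step function is reused each time
        frames.append(decode(hidden))  # decoder applied to an intermediate state
    return frames


# Toy stand-ins so the sketch runs without a trained model.
toy_frames = decode_intermediates(
    0,
    encode=lambda v: v,
    rnn_step=lambda h: h + 1,
    decode=lambda h: h * 2,
    num_steps=3,
)
```

With real networks, only the last frame is trained against a target; the earlier frames are the "imagined" states that get inspected.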

Paper Structure (16 sections, 1 equation, 5 figures, 3 tables)

This paper contains 16 sections, 1 equation, 5 figures, 3 tables.

Figures (5)

  • Figure 1: Our setup with 3 GoL and 3 model timesteps. Left: Our setup during training. The RNN layer blocks share weights. The Encoder, RNN layers and Decoder are all convolutional. Middle: Our autoencoder training setup. Right: Our setup during inference. Our hypothesis is that the states that are generated when applying the decoder to the intermediate timesteps are similar to the intermediate GoL states.
  • Figure 2: An example with 3 GoL timesteps and 3 model timesteps. The metric gets a value of .5 because the 3rd predicted state matches the ground truth, but the 2nd predicted state does not. The 1st and 4th states are ignored because the network is trained on them.
  • Figure 3: The metric gets a value of .75 because both predicted states match the same ground truth state, so the 3rd predicted state receives only half credit.
  • Figure 4: An example with 2 GoL timesteps and 3 model timesteps. The metric gets a value of 1.0. The 2nd model state matches the intermediate GoL state. The 3rd model state matches the 3rd GoL state, but this is ignored because the model was trained to predict the 3rd GoL state. The denominator in the metric is 1 because $\min(2, 3-2) = 1$.
  • Figure 5: An example with 2 GoL timesteps and 3 model timesteps. The metric gets a value of 1.5. The 2nd model state matches the intermediate GoL state. The 3rd model state also matches the intermediate GoL state, so the metric gives .5 score for this state. Note that we ignore the last GoL state and not the last model state, so this match is a valid one.
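
The scoring rule illustrated in Figures 2 to 5 can be sketched in a few lines. This is one reading that reproduces the four example values, not the paper's exact definition. Several details are assumptions: a "match" means exact equality after thresholding at 0.5, every repeated match to an already-matched GoL state earns 0.5 (the figures only show a second match), and the caller passes in only the states that are eligible for scoring. The name `interpretability_score` is hypothetical.

```python
from typing import Sequence, Tuple

Grid = Sequence[Sequence[float]]


def binarize(grid: Grid, threshold: float = 0.5) -> Tuple[Tuple[int, ...], ...]:
    """Threshold a grid of cell values into a hashable 0/1 board."""
    return tuple(tuple(1 if v > threshold else 0 for v in row) for row in grid)


def interpretability_score(
    model_states: Sequence[Grid],  # decoded states eligible for scoring
    gol_targets: Sequence[Grid],   # intermediate ground-truth GoL states
    threshold: float = 0.5,        # assumption: cutoff for a live cell
) -> float:
    """Sum match credit over model states, divided by min of the two counts."""
    denom = min(len(model_states), len(gol_targets))
    if denom == 0:
        return 0.0
    targets = [binarize(g, threshold) for g in gol_targets]
    already_matched = set()
    score = 0.0
    for state in model_states:
        board = binarize(state, threshold)
        for idx, target in enumerate(targets):
            if board == target:
                # Full credit for the first match, half credit for repeats.
                score += 0.5 if idx in already_matched else 1.0
                already_matched.add(idx)
                break
    return score / denom


# Tiny 2x2 boards standing in for GoL states.
a = [[0.9, 0.1], [0.2, 0.8]]
b = [[0.1, 0.9], [0.9, 0.1]]
c = [[0.9, 0.9], [0.1, 0.1]]

fig2 = interpretability_score([c, b], [a, b])  # one of two states matches
fig3 = interpretability_score([a, a], [a, b])  # both match the same state
fig4 = interpretability_score([a, c], [a])     # denominator is min(2, 1) = 1
fig5 = interpretability_score([a, a], [a])     # repeat match, denominator 1
```

Because the denominator is the smaller of the two counts while the numerator sums over all model states, the score can exceed 1, as in the Figure 5 case.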