Learning Useful Representations of Recurrent Neural Network Weight Matrices
Vincent Herrmann, Francesco Faccio, Jürgen Schmidhuber
TL;DR
The paper investigates how to learn meaningful representations of RNN weight matrices by contrasting mechanistic and functionalist encoders, introducing six architectures with a focus on probing-based methods. It develops an emulation-based self-supervised framework that trains a weight encoder to produce representations from which an emulator can mimic the RNN's behavior, and it releases two model-zoo datasets (formal languages and tiled Sequential MNIST) to benchmark the approaches. Empirical results show that interactive probing is particularly effective for complex, algorithmic tasks (formal languages), while probing methods also improve downstream property prediction on MNIST, outperforming purely supervised baselines in several settings. The work lays a foundation for interpretable weight representations, with potential applications in reinforcement learning, meta-learning, and large-scale sequence models, and highlights the trade-offs among probing strategies, invariances, and training stability.
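The emulation-based objective summarized above can be illustrated with a minimal sketch. It assumes a toy target RNN with a 1-dimensional hidden state, a hypothetical linear encoder, a hypothetical emulator RNN that receives the code z as its bias, and squared error as the discrepancy measure; none of these names or choices come from the paper. What it is meant to show is the data flow: the encoder sees only the weights, the emulator sees only z and the inputs, and the training targets are the RNN's own outputs, so no task labels are needed.

```python
import math
import random

def target_rnn(theta, xs):
    """Toy target RNN with a 1-d hidden state; theta = (a, w, b, c) plays the role of its weights."""
    a, w, b, c = theta
    h, ys = 0.0, []
    for x in xs:
        h = math.tanh(a * x + w * h + b)
        ys.append(c * h)
    return ys

def encoder(theta, P):
    """Hypothetical weight encoder (assumption): a fixed linear projection z = P @ theta."""
    return [sum(p_j * t_j for p_j, t_j in zip(row, theta)) for row in P]

def emulator(z, xs, phi):
    """Hypothetical emulator (assumption): its own small RNN, conditioned on z through the bias.
    phi = (U, V, O) are shared emulator parameters; only z depends on the target RNN."""
    U, V, O = phi
    d = len(z)
    h, ys = [0.0] * d, []
    for x in xs:
        h = [math.tanh(U[i] * x + sum(V[i][k] * h[k] for k in range(d)) + z[i]) for i in range(d)]
        ys.append(sum(O[i] * h[i] for i in range(d)))
    return ys

def emulation_loss(theta, P, phi, input_seqs):
    """Self-supervised loss: mean squared gap between the target RNN's outputs and the
    emulator's outputs, where the emulator only receives z = encoder(theta)."""
    z = encoder(theta, P)
    total, count = 0.0, 0
    for xs in input_seqs:
        for y_true, y_emu in zip(target_rnn(theta, xs), emulator(z, xs, phi)):
            total += (y_true - y_emu) ** 2
            count += 1
    return total / count

rng = random.Random(0)
d = 3  # size of the code z (assumption)
theta = [rng.gauss(0.0, 1.0) for _ in range(4)]
P = [[rng.gauss(0.0, 0.5) for _ in range(4)] for _ in range(d)]
phi = (
    [rng.gauss(0.0, 0.5) for _ in range(d)],
    [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(d)],
    [rng.gauss(0.0, 0.5) for _ in range(d)],
)
seqs = [[rng.uniform(-1.0, 1.0) for _ in range(6)] for _ in range(4)]
loss = emulation_loss(theta, P, phi, seqs)
print(loss >= 0.0)  # True
```

In a training setup along these lines, the encoder and emulator parameters (here P and phi) would be optimized jointly across a model zoo to drive this loss down, which forces z to carry whatever information about the weights the emulator needs to reproduce the RNN's behavior.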
Abstract
Recurrent Neural Networks (RNNs) are general-purpose parallel-sequential computers. The program of an RNN is its weight matrix. How can we learn useful representations of RNN weights that facilitate both RNN analysis and downstream tasks? While the mechanistic approach directly inspects an RNN's weights to predict its behavior, the functionalist approach analyzes its overall functionality, specifically its input-output mapping. We consider several mechanistic approaches for RNN weights and adapt the permutation-equivariant Deep Weight Space layer for RNNs. Our two novel functionalist approaches extract information from RNN weights by 'interrogating' the RNN through probing inputs. We develop a theoretical framework that demonstrates conditions under which the functionalist approach can generate rich representations that help determine RNN behavior. We release the first two 'model zoo' datasets for RNN weight representation learning. One consists of generative models of a class of formal languages, and the other of classifiers of sequentially processed MNIST digits. Using an emulation-based self-supervised learning technique, we compare and evaluate the different RNN weight encoding techniques on multiple downstream applications. On the most challenging of these, predicting exactly which task the RNN was trained on, functionalist approaches are clearly superior.
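The functionalist idea of 'interrogating' an RNN through probing inputs can be illustrated with a minimal sketch. It assumes a vanilla tanh RNN stored as plain Python lists, a few fixed random probe sequences, and a representation formed by concatenating the RNN's outputs on those probes; all function names are hypothetical. The paper's own probing encoders are more elaborate (the TL;DR singles out an interactive variant), so this shows only the core principle of reading a network through its input-output behavior rather than its raw weights.

```python
import math
import random

def rnn_outputs(weights, inputs):
    """Vanilla tanh RNN: h_t = tanh(W_in x_t + W_rec h_{t-1} + b), y_t = W_out h_t."""
    W_in, W_rec, b, W_out = weights
    n = len(b)
    h = [0.0] * n
    outputs = []
    for x in inputs:
        h = [
            math.tanh(
                sum(w * xj for w, xj in zip(W_in[i], x))
                + sum(w * hk for w, hk in zip(W_rec[i], h))
                + b[i]
            )
            for i in range(n)
        ]
        outputs.append([sum(w * hi for w, hi in zip(row, h)) for row in W_out])
    return outputs

def probe_representation(weights, probes):
    """Hypothetical functionalist encoder: concatenate the outputs on fixed probe sequences."""
    rep = []
    for seq in probes:
        for y in rnn_outputs(weights, seq):
            rep.extend(y)
    return rep

def permute_hidden(weights, perm):
    """Reorder hidden units; the result computes the same function up to float rounding."""
    W_in, W_rec, b, W_out = weights
    return (
        [W_in[p] for p in perm],
        [[W_rec[p][q] for q in perm] for p in perm],
        [b[p] for p in perm],
        [[row[p] for p in perm] for row in W_out],
    )

def random_rnn(n_in, n_hidden, n_out, rng):
    def mat(rows, cols):
        return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]
    bias = [rng.gauss(0.0, 0.5) for _ in range(n_hidden)]
    return mat(n_hidden, n_in), mat(n_hidden, n_hidden), bias, mat(n_out, n_hidden)

rng = random.Random(0)
weights = random_rnn(n_in=2, n_hidden=4, n_out=3, rng=rng)
# 3 probe sequences, each 5 steps of 2-dimensional input
probes = [[[rng.uniform(-1.0, 1.0) for _ in range(2)] for _ in range(5)] for _ in range(3)]

z = probe_representation(weights, probes)
z_perm = probe_representation(permute_hidden(weights, [2, 0, 3, 1]), probes)
print(len(z))  # 45 = 3 probes * 5 steps * 3 outputs
print(all(math.isclose(a, c, abs_tol=1e-9) for a, c in zip(z, z_perm)))  # True
```

Because this representation depends only on input-output behavior, permuting the hidden units leaves it unchanged up to floating-point rounding. Mechanistic encoders that read the weights directly, such as the permutation-equivariant Deep Weight Space layer, must instead build this symmetry into their architecture.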
