Learning Useful Representations of Recurrent Neural Network Weight Matrices

Vincent Herrmann, Francesco Faccio, Jürgen Schmidhuber

TL;DR

The paper investigates how to learn meaningful representations of RNN weight matrices by contrasting mechanistic and functionalist encoders, introducing six architectures with a focus on probing-based methods. It develops an emulation-based self-supervised framework that trains a weight-encoder to produce representations enabling an Emulator to mimic RNN behavior, and releases two model-zoo datasets (formal languages and tiled Sequential MNIST) to benchmark approaches. Empirical results show that interactive probing is particularly effective for complex, algorithmic tasks (formal languages), while probing methods also improve downstream property prediction on MNIST, outperforming purely supervised baselines in several settings. The work establishes a foundation for interpretable weight representations with potential applications in reinforcement learning, meta-learning, and large-scale sequence models, and highlights the trade-offs between probing strategies, invariances, and training stability.
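The emulation-based self-supervised objective can be illustrated with a minimal NumPy sketch of its forward pass. The following are assumptions, not the paper's setup. f_θ is a single-layer vanilla tanh RNN, whereas the paper uses LSTMs. The encoder E is a toy probing encoder that uses fixed probe sequences. The emulator A is a hypothetical RNN conditioned on E(θ) through a bias term. The loss is mean squared error. All function names (`rnn_forward`, `encode`, `emulation_loss`) are illustrative only. Training would minimize this loss over the emulator's parameters φ (and over the probes, if they are learned) with an autodiff framework, which is not shown here.

```python
import numpy as np

def rnn_forward(W_ih, W_hh, W_out, xs, bias=None):
    """Vanilla tanh RNN: h_t = tanh(W_ih x_t + W_hh h_{t-1} + bias), y_t = W_out h_t."""
    h = np.zeros(W_hh.shape[0])
    bias = np.zeros_like(h) if bias is None else bias
    ys = []
    for x in xs:
        h = np.tanh(W_ih @ x + W_hh @ h + bias)
        ys.append(W_out @ h)
    return np.stack(ys)

def encode(theta, probes):
    """Toy probing encoder E: concatenate f_theta's responses to fixed probe sequences."""
    return np.concatenate([rnn_forward(*theta, p).ravel() for p in probes])

def emulation_loss(theta, phi, probes, xs):
    """Mean squared error between f_theta(xs) and the emulator's prediction A(E(theta), xs)."""
    z = encode(theta, probes)
    U_ih, U_hh, U_out, U_z = phi
    target = rnn_forward(*theta, xs)
    # The emulator is itself an RNN whose hidden update is conditioned on z via a bias term.
    pred = rnn_forward(U_ih, U_hh, U_out, xs, bias=U_z @ z)
    return float(np.mean((pred - target) ** 2))

rng = np.random.default_rng(0)
d_in, n, d_out, T, n_probes, m = 3, 8, 2, 5, 4, 16
theta = (rng.normal(size=(n, d_in)),
         rng.normal(size=(n, n)) / np.sqrt(n),
         rng.normal(size=(d_out, n)))
probes = [rng.normal(size=(T, d_in)) for _ in range(n_probes)]
z_dim = n_probes * T * d_out
phi = (rng.normal(size=(m, d_in)),
       rng.normal(size=(m, m)) / np.sqrt(m),
       rng.normal(size=(d_out, m)),
       0.1 * rng.normal(size=(m, z_dim)))
xs = rng.normal(size=(T, d_in))
loss = emulation_loss(theta, phi, probes, xs)
```

As a sanity check, you can copy θ into the emulator's weights and set its conditioning matrix to zero. The emulator then reproduces f_θ exactly, and the loss is zero. In this framing, a useful encoder is one whose embeddings let a single shared emulator reach low loss across many different θ.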

Abstract

Recurrent Neural Networks (RNNs) are general-purpose parallel-sequential computers. The program of an RNN is its weight matrix. How can we learn useful representations of RNN weights that facilitate RNN analysis as well as downstream tasks? While the mechanistic approach directly looks at an RNN's weights to predict its behavior, the functionalist approach analyzes its overall functionality, specifically its input-output mapping. We consider several mechanistic approaches for RNN weights and adapt the permutation-equivariant Deep Weight Space layer for RNNs. Our two novel functionalist approaches extract information from RNN weights by 'interrogating' the RNN through probing inputs. We develop a theoretical framework that demonstrates conditions under which the functionalist approach can generate rich representations that help determine RNN behavior. We release the first two 'model zoo' datasets for RNN weight representation learning. One consists of generative models of a class of formal languages, and the other of classifiers of sequentially processed MNIST digits. With the help of an emulation-based self-supervised learning technique, we compare and evaluate the different RNN weight encoding techniques on multiple downstream applications. On the most challenging one, namely predicting which exact task the RNN was trained on, functionalist approaches show clear superiority.
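The mechanistic/functionalist distinction rests on a well-known symmetry of RNN weights. Permuting the hidden neurons changes the weight matrices but leaves the input-output mapping unchanged. A mechanistic encoder that reads raw weights must cope with this symmetry, which is what a permutation-equivariant layer is designed to respect. A probe-based functionalist encoder inherits the invariance automatically. The minimal NumPy sketch below demonstrates the symmetry under two assumptions: the network is a single-layer vanilla tanh RNN rather than the LSTMs used in the paper, and the helper `permute_hidden` is hypothetical.

```python
import numpy as np

def rnn_forward(W_ih, W_hh, W_out, xs):
    """Vanilla tanh RNN: h_t = tanh(W_ih x_t + W_hh h_{t-1}), y_t = W_out h_t."""
    h = np.zeros(W_hh.shape[0])
    ys = []
    for x in xs:
        h = np.tanh(W_ih @ x + W_hh @ h)
        ys.append(W_out @ h)
    return np.stack(ys)

def permute_hidden(W_ih, W_hh, W_out, perm):
    """Reorder hidden neurons: rows of W_ih, rows and columns of W_hh, columns of W_out."""
    return W_ih[perm], W_hh[np.ix_(perm, perm)], W_out[:, perm]

rng = np.random.default_rng(1)
n, d_in, d_out, T = 6, 3, 2, 7
theta = (rng.normal(size=(n, d_in)),
         rng.normal(size=(n, n)) / np.sqrt(n),
         rng.normal(size=(d_out, n)))
perm = np.roll(np.arange(n), 1)  # a fixed non-identity permutation
theta_p = permute_hidden(*theta, perm)
probe = rng.normal(size=(T, d_in))

# Mechanistic view: the flattened weight vectors differ.
flat = np.concatenate([W.ravel() for W in theta])
flat_p = np.concatenate([W.ravel() for W in theta_p])
weights_differ = not np.allclose(flat, flat_p)

# Functionalist view: the responses to the same probe are identical.
same_function = np.allclose(rnn_forward(*theta, probe), rnn_forward(*theta_p, probe))
```

Both flags come out `True`. Two weight vectors that look different to a naive encoder implement the same program.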

Paper Structure (28 sections, 7 theorems, 5 equations, 24 figures, 11 tables)

Key Result

Proposition 3.0

Any function $f_C$ from a set $D$ can be identified by an interrogator through at most $|D| - 1$ interactions.
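A simple elimination argument illustrates a bound of this form. It assumes that D is finite and that any two distinct functions in D disagree on at least one available query. The interrogator picks a query on which two surviving candidates disagree and observes f_C's answer. At least one of those two candidates must then be discarded, while f_C itself always survives. This plain-Python sketch is not necessarily the paper's proof, and the interrogators in the paper are learned networks. The names `interrogate`, `candidates`, and `queries` are hypothetical.

```python
from typing import Callable, Dict, Iterable, Tuple

def interrogate(candidates: Dict[str, Callable], oracle: Callable,
                queries: Iterable) -> Tuple[str, int]:
    """Identify which candidate function the oracle implements by elimination.

    Each round asks one query on which two surviving candidates disagree and
    drops every candidate whose answer differs from the oracle's. At least one
    candidate is dropped per round, so at most len(candidates) - 1 rounds occur.
    """
    queries = list(queries)
    alive = dict(candidates)
    interactions = 0
    while len(alive) > 1:
        a, b = list(alive.values())[:2]
        # Assumes a and b are distinguishable; next() raises StopIteration otherwise.
        q = next(q for q in queries if a(q) != b(q))
        answer = oracle(q)
        interactions += 1
        alive = {k: f for k, f in alive.items() if f(q) == answer}
    (name,) = alive
    return name, interactions

candidates = {
    "zero": lambda x: 0,
    "id": lambda x: x,
    "square": lambda x: x * x,
    "neg": lambda x: -x,
}
queries = range(-3, 4)
print(interrogate(candidates, candidates["square"], queries))  # → ('square', 1)
```

The bound is a worst case. In this example a single well-chosen query separates all four candidates at once.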

Figures (24)

  • Figure 1: RNN weight encoder architectures taking weights $\theta$ as input and producing a representation $E(\theta)$. The two groups of four weight matrices symbolize the four gates of two LSTM layers. The last matrix represents the output projection.
  • Figure 1: Properties of the different RNN weight encoder architectures. $N$ is the number of hidden neurons in $f_\theta$.
  • Figure 2: Comparison of non-interactive and interactive probing encoders.
  • Figure 3: Interrogator $I_D$ has access to a set $D$ of functions and interacts with function $f_C$, which it has to identify.
  • Figure 4: Emulation-based self-supervised training. The encoder $E$ is trained to generate embeddings of $\theta$ that allow $A$ to emulate $f_\theta$.
  • ...and 19 more figures
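The NumPy sketch below makes the distinction in Figure 2 concrete. A non-interactive probing encoder feeds every RNN the same probe sequence, fixed in advance. An interactive one lets an interrogator choose each next input based on how this particular RNN has responded so far. The sketch rests on three assumptions. The network is a vanilla tanh RNN, not an LSTM. A fixed random linear map followed by tanh stands in for the learned interrogator; this is hypothetical. The network that maps the collected responses to E(θ) is omitted. The function names are illustrative only.

```python
import numpy as np

def rnn_step(theta, h, x):
    """One step of a vanilla tanh RNN f_theta; returns the new state and the output."""
    W_ih, W_hh, W_out = theta
    h = np.tanh(W_ih @ x + W_hh @ h)
    return h, W_out @ h

def non_interactive_probe(theta, probe):
    """Feed a probe sequence fixed in advance (the same for every theta); record outputs."""
    h = np.zeros(theta[1].shape[0])
    ys = []
    for x in probe:
        h, y = rnn_step(theta, h, x)
        ys.append(y)
    return np.concatenate(ys)

def interactive_probe(theta, policy, x0, T):
    """Let an interrogator policy pick each next input from the RNN's latest output."""
    h = np.zeros(theta[1].shape[0])
    x, ys = x0, []
    for _ in range(T):
        h, y = rnn_step(theta, h, x)
        ys.append(y)
        x = policy(y)  # the next query depends on how this particular RNN responded
    return np.concatenate(ys)

rng = np.random.default_rng(2)
d_in, n, d_out, T = 3, 8, 2, 6
theta = (rng.normal(size=(n, d_in)),
         rng.normal(size=(n, n)) / np.sqrt(n),
         rng.normal(size=(d_out, n)))
P = rng.normal(size=(d_in, d_out))  # hypothetical stand-in for a learned interrogator

def policy(y):
    return np.tanh(P @ y)

ys_fixed = non_interactive_probe(theta, rng.normal(size=(T, d_in)))
ys_adaptive = interactive_probe(theta, policy, np.zeros(d_in), T)
```

With a constant policy, the interactive probe reduces to the non-interactive one, so non-interactive probing is a special case. An adaptive policy can instead steer its queries toward inputs that distinguish the RNN at hand.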

Theorems & Definitions (11)

  • Proposition 3.0
  • Proposition 3.0
  • Proposition 3.0
  • Lemma 1.1
  • proof
  • Proposition 1.1
  • proof
  • Proposition 1.1
  • proof
  • Proposition 1.1
  • ...and 1 more