Weight-based Decomposition: A Case for Bilinear MLPs

Michael T. Pearce, Thomas Dooms, Alice Rigg

TL;DR

This work tackles the interpretability of MLP components in transformers by focusing on bilinear MLPs, which drop the nonlinearity in the gate yet perform comparably to other GLUs and can be expressed exactly with a third-order tensor and linear operations. The authors introduce a weight-based eigenvector decomposition that rewrites the bilinear computation as a set of sparsely interacting features, fully equivalent to the original model. They demonstrate interpretable top eigenvectors on MNIST, provide preliminary evidence of interpretable language-model features on Tiny Stories, and show that a pretrained model (TinyLlama-1.1B) can be finetuned into a bilinear variant with competitive loss. Regularization via latent noise improves interpretability and can improve generalization. The authors note limitations, including polysemanticity and scalability. The approach offers a potential bridge between weights and interpretable features, with implications for mechanistic interpretability and model debugging.

Abstract

Gated Linear Units (GLUs) have become a common building block in modern foundation models. Bilinear layers drop the non-linearity in the "gate" but still have comparable performance to other GLUs. An attractive quality of bilinear layers is that they can be fully expressed in terms of a third-order tensor and linear operations. Leveraging this, we develop a method to decompose the bilinear tensor into a set of sparsely interacting eigenvectors that show promising interpretability properties in preliminary experiments for shallow image classifiers (MNIST) and small language models (Tiny Stories). Since the decomposition is fully equivalent to the model's original computations, bilinear layers may be an interpretability-friendly architecture that helps connect features to the model weights. Application of our method may not be limited to pretrained bilinear models since we find that language models such as TinyLlama-1.1B can be finetuned into bilinear variants.
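
To make the third-order-tensor claim concrete, here is a minimal NumPy sketch. It assumes a bilinear layer of the form `P @ ((W @ x) * (V @ x))`; the weight names `W`, `V`, `P`, the toy sizes, and the random weights are illustrative assumptions, not the authors' code. It builds a tensor `B` from the weights and checks that contracting it twice with the input reproduces the layer output.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden, d_out = 4, 6, 3   # toy sizes, chosen only for illustration

# Hypothetical weights of a bilinear MLP: out = P @ ((W @ x) * (V @ x))
W = rng.normal(size=(d_hidden, d_in))
V = rng.normal(size=(d_hidden, d_in))
P = rng.normal(size=(d_out, d_hidden))

def bilinear_mlp(x):
    """Bilinear layer: a GLU whose gate has no nonlinearity, then a projection."""
    return P @ ((W @ x) * (V @ x))

# Third-order tensor B[o, i, j] = sum_h P[o, h] * W[h, i] * V[h, j]
B = np.einsum("oh,hi,hj->oij", P, W, V)
# Only the part symmetric in (i, j) contributes to x^T B[o] x
B_sym = 0.5 * (B + B.transpose(0, 2, 1))

def tensor_form(x):
    """Same computation expressed as out[o] = x^T B_sym[o] x."""
    return np.einsum("oij,i,j->o", B_sym, x, x)

x = rng.normal(size=d_in)
max_abs_diff = float(np.max(np.abs(bilinear_mlp(x) - tensor_form(x))))
print(max_abs_diff)  # should be at floating-point rounding level
```

Symmetrizing `B` over its two input indices leaves the output unchanged, and it is what makes each output slice a symmetric matrix that can later be eigendecomposed.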

Paper Structure (33 sections, 1 theorem, 6 equations, 13 figures, 7 tables)

Key Result

Theorem 3.1

If $Q: \mathbb{R}^d \to \mathbb{R}^d$ is a real, symmetric matrix, then there exists an orthonormal basis of $\mathbb{R}^d$ consisting of eigenvectors of $Q$, and every eigenvalue of $Q$ is real. That is, $Q=P^T\Lambda P$, where $\Lambda$ is a real diagonal matrix of eigenvalues and $P$ is a real orthogonal matrix whose rows are the corresponding eigenvectors.
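
The theorem can be checked numerically. The sketch below uses a random symmetric matrix as a stand-in for an interaction matrix and `numpy.linalg.eigh`; the variable names are illustrative. It verifies both the factorization $Q = P^T \Lambda P$ and the orthonormality of the eigenvectors.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 5
A = rng.normal(size=(d, d))
Q = 0.5 * (A + A.T)            # an arbitrary real symmetric matrix

# eigh is NumPy's routine for symmetric/Hermitian matrices;
# eigenvalues come back in ascending order, eigenvectors as columns
eigvals, eigvecs = np.linalg.eigh(Q)

P_mat = eigvecs.T              # rows of P are eigenvectors, matching Q = P^T Λ P
Lam = np.diag(eigvals)

recon_err = float(np.max(np.abs(P_mat.T @ Lam @ P_mat - Q)))
orth_err = float(np.max(np.abs(P_mat @ P_mat.T - np.eye(d))))
print(recon_err, orth_err)     # both should be at floating-point rounding level
```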

Figures (13)

  • Figure 1: Illustration of the eigenvalue decomposition. A) For a single layer, we start with a vector $\mathbf{u}$ in the output space of the third-order tensor $\mathcal{B}$. Their dot product gives a symmetric interaction matrix $Q$. Eigendecomposition of $Q$ gives an orthonormal set of eigenvectors $\mathbf{v}_i$. In this basis, the interactions are sparse since an eigenvector only interacts with itself. B) For a fully bilinear model, we start with an unembed vector $\mathbf{u}$ (e.g., for the digit "3") and repeatedly apply the single-layer eigendecomposition. Each layer-$k$ eigenvector $\mathbf{v}_{\{i_N, \dots, i_k\}}$ acts as the output vector that determines a set of layer-$(k-1)$ eigenvectors $\mathbf{v}_{\{i_N, \dots, i_k, i_{k-1}\}}$. Using the embedding weights, the layer-1 eigenvectors can be transformed into input features. C) Schematic of a 2-layer model after full decomposition showing the tree-like computational graph, with a branching factor of 2 (instead of $d_{\text{model}}$) for simplicity. The model is decompiled in the sense that interactions have been sparsified and made explicit through the graph. Only the layer-1 eigenvectors are needed to get the initial activations from the post-embedding inputs. The eigenvalue magnitudes parameterize the importance of the edges.
  • Figure 2: Eigenvectors for single-layer MNIST and Fashion-MNIST models. A) Top positive eigenvectors by mean activation for the first seven output classes. To the right of each eigenvector are the 3 inputs with the highest activations. Ordering by activation favors eigenvectors with more "complete" images that overlap more with the inputs. Often the top-activating eigenvector has the largest or second-largest eigenvalue. B) Because of the squared dot product in an eigenvector's activation, $\lambda_i (\mathbf{v}_i^T \mathbf{x})^2$, its overall sign is arbitrary, so we choose the sign so that the positive component (blue) is similar to the top inputs. If an input has a high overlap with only the positive (blue) or only the negative (red) component, then the activation can be large, but a high overlap with both components cancels out. This behavior is similar to an XOR. C) The top three positive and negative eigenvectors for digit 1. The top positive eigenvector has strong activations for digit 4, but these likely cancel out with activations for the top negative eigenvector, which detects horizontal lines in the center of an image.
  • Figure 3: The validation accuracy for an MNIST model truncated to the top $k$ eigenvectors by eigenvalue magnitude.
  • Figure 4: Cosine similarities for eigenvectors that best match an eigenvector from a $d=300$ model, averaged over 5 initializations per model size. The eigenvectors for digit 3 show that moderate similarities can correspond to visually similar features. Results shown are only for positive eigenvectors and similarities are computed for eigenvectors transformed to the input basis using the embedding weights.
  • Figure 5: Comparison of different regularizations. Latent noise regularization adds dense Gaussian noise to the inputs of each layer, including the embedding and unembedding layers. We use noise with a standard deviation equal to 0.33 times the standard deviation of the input. Latent noise is found to remove the large values in the periphery of the image and improve validation accuracy. Weight decay (parameter = 0.5) reduces the fine-scale noisiness in the images and reduces performance. The results in the paper are from models with weight decay and latent noise.
  • ...and 8 more figures
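
The single-layer procedure described in the captions for Figures 1 to 3 can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions: the tensor `B` is random rather than trained, `u` is a hypothetical one-hot output direction, and the helper names `eigen_activations` and `truncated_output` are invented for this example. It forms the interaction matrix $Q$ from $\mathbf{u}$ and $\mathcal{B}$, computes the per-eigenvector activations $\lambda_i (\mathbf{v}_i^T \mathbf{x})^2$, and truncates to the top $k$ eigenvectors by eigenvalue magnitude.

```python
import numpy as np

rng = np.random.default_rng(2)
d_in, d_out = 6, 3

# Stand-in for a symmetrized bilinear tensor B[o, i, j] (random, not trained)
B = rng.normal(size=(d_out, d_in, d_in))
B = 0.5 * (B + B.transpose(0, 2, 1))

u = np.zeros(d_out)
u[1] = 1.0                       # hypothetical output direction, e.g. one class logit

# Interaction matrix Q = u . B (contract over the output index); Q is symmetric
Q = np.einsum("o,oij->ij", u, B)

eigvals, eigvecs = np.linalg.eigh(Q)          # eigenvectors are the columns
order = np.argsort(-np.abs(eigvals))          # sort by eigenvalue magnitude
eigvals, eigvecs = eigvals[order], eigvecs[:, order]

def eigen_activations(x):
    """Per-eigenvector contributions lambda_i * (v_i^T x)^2."""
    return eigvals * (eigvecs.T @ x) ** 2

def truncated_output(x, k):
    """Approximate the output along u using only the top-k eigenvectors."""
    return float(eigen_activations(x)[:k].sum())

x = rng.normal(size=d_in)
full = float(x @ Q @ x)                       # exact output along u
errors = [abs(truncated_output(x, k) - full) for k in range(d_in + 1)]
print(errors[-1])  # using all eigenvectors recovers the exact output up to rounding
```

Because each activation depends on $(\mathbf{v}_i^T \mathbf{x})^2$, flipping the sign of any eigenvector leaves the activations unchanged, which is why the sign convention in Figure 2 is a free choice.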

Theorems & Definitions (1)

  • Theorem 3.1: Spectral Theorem