Converting MLPs into Polynomials in Closed Form

Nora Belrose, Alice Rigg

TL;DR

This work develops a principled, analytic framework to convert pretrained MLPs and GLUs into polynomial functions that globally minimize MSE under a maximum-entropy input model, enabling closed-form linear and quadratic approximants. On Gaussian-mixture approximations of MNIST, quadratic approximants explain $>94$–$95\%$ of the variance in network outputs, facilitating mechanistic interpretability via spectral decompositions and enabling SVD-based adversarial attacks that transfer to the original networks. The study also reveals training-time dynamics consistent with the distributional simplicity bias, showing an initial phase where networks appear simpler and a later phase where nonlinear (quadratic) structure dominates. These results provide a mathematically grounded lens for understanding network representations and suggest extensions to transformers and FFN interpretability using polynomial bases and spectral methods.

Abstract

Recent work has shown that purely quadratic functions can replace MLPs in transformers with no significant loss in performance, while enabling new methods of interpretability based on linear algebra. In this work, we theoretically derive closed-form least-squares optimal approximations of feedforward networks (multilayer perceptrons and gated linear units) using polynomial functions of arbitrary degree. When the $R^2$ is high, this allows us to interpret MLPs and GLUs by visualizing the eigendecomposition of the coefficients of their linear and quadratic approximants. We also show that these approximants can be used to create SVD-based adversarial examples. By tracing the $R^2$ of linear and quadratic approximants across training time, we find new evidence that networks start out simple, and get progressively more complex. Even at the end of training, however, our quadratic approximants explain over 95% of the variance in network outputs.
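As a one-dimensional illustration of a least-squares optimal linear approximant and its $R^2$, the sketch below fits a line to a single ReLU under a standard Gaussian input. This is a toy example under stated assumptions, not the paper's general derivation: the choice of ReLU, the $\mathcal{N}(0, 1)$ input, and all function names are hypothetical. The closed-form moments used are $\mathbb{E}[\mathrm{ReLU}(X)] = 1/\sqrt{2\pi}$, $\mathrm{Cov}(X, \mathrm{ReLU}(X)) = 1/2$, and $\mathbb{E}[\mathrm{ReLU}(X)^2] = 1/2$. The fraction of variance unexplained (FVU) is $1 - R^2$.

```python
import math
import random


def relu(x: float) -> float:
    return max(x, 0.0)


def closed_form_linear_fit():
    """Least-squares line a + b*x for relu(x) when x ~ N(0, 1).

    Illustrative only: uses known Gaussian moments of the ReLU.
    """
    mean_y = 1.0 / math.sqrt(2.0 * math.pi)   # E[relu(X)]
    cov_xy = 0.5                              # Cov(X, relu(X)) = P(X > 0)
    var_y = 0.5 - 1.0 / (2.0 * math.pi)       # E[relu(X)^2] - E[relu(X)]^2
    slope = cov_xy / 1.0                      # Var(X) = 1
    intercept = mean_y - slope * 0.0          # E[X] = 0
    fvu = 1.0 - slope ** 2 / var_y            # FVU = 1 - R^2
    return intercept, slope, fvu


def monte_carlo_fvu(intercept: float, slope: float, n: int = 200_000, seed: int = 0) -> float:
    """Estimate the FVU of the line on samples drawn from N(0, 1)."""
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ys = [relu(x) for x in xs]
    mean_y = sum(ys) / n
    ss_res = sum((y - (intercept + slope * x)) ** 2 for x, y in zip(xs, ys))
    ss_tot = sum((y - mean_y) ** 2 for y in ys)
    return ss_res / ss_tot


intercept, slope, fvu_exact = closed_form_linear_fit()
fvu_mc = monte_carlo_fvu(intercept, slope)
print(f"intercept={intercept:.4f} slope={slope:.4f}")
print(f"FVU closed form={fvu_exact:.4f} Monte Carlo={fvu_mc:.4f}")
```

In this toy case the linear fit leaves roughly a quarter of the variance unexplained, which is the kind of gap a quadratic term can reduce.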

Paper Structure

This paper contains 28 sections, 1 theorem, 37 equations, 5 figures.

Key Result

Theorem 3.1

Let $X, Y_1, \ldots, Y_n$ be $n+1$ jointly Gaussian random variables, and let $g : \mathbb{R} \rightarrow \mathbb{R}$ be a continuous, real-valued function. Then: where the coefficients $a_k$ can be computed analytically in the manner described below.

Figures (5)

  • Figure 1: Quadratic and linear features for class '3', over the course of training. Discernible '3' qualities arise from noise for both quadratic and linear features, aligning with the region where learning is happening. The linear 3 structure is most intuitively discernible at step (y), when FVU is minimal, before beginning to overfit. This can be interpreted as the MLP learning and relying on statistics of higher complexity than linear, especially if its accuracy continues to improve. The quadratic feature crystallizes later than linear, and predictably forms visual artifacts at the latest stages of training.
  • Figure 2: Fraction of variance unexplained (FVU) for linear and quadratic approximants on a Gaussian mixture distribution imitating the MNIST training set. There is a sharp increase in the linear FVU between 500 and 1K training steps, while the quadratic FVU is roughly constant over the same time period.
  • Figure 3: KL divergence of linear and quadratic approximants from the network on which they were fit, evaluated on a Gaussian mixture distribution imitating the MNIST training set. The trend mirrors the FVU plot (Figure 2) except that the KL does not decrease in the first few hundred steps before increasing.
  • Figure 4: Top row: Adversarial '3' for different intervention strengths, ranging from one to ten SVD components ablated (left to right). Bottom row: Random examples of each digit with all ten SVD components ablated. Examples bear high resemblance to the original images, despite being unclassifiable by the MLP.
  • Figure 5: Steering with an adversarial mask. Strikingly, the original network's accuracy drops in perfect lockstep with that of its linear and quadratic approximants. Just four SVD components are needed to bring accuracy below 50%, and ablating the entire rowspace (only 10 out of 784 total dimensions) is sufficient to make the MLP perform no better than random chance. This indicates that the approximations are appropriately capturing the generalization behavior of the MLP.
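
The SVD ablation described in Figures 4 and 5 can be sketched as follows. This is a minimal illustration, assuming the mask is built from the right singular vectors of a linear approximant's $10 \times 784$ coefficient matrix. The random matrix `W`, the random input `x`, and the helper `ablate_top_components` are hypothetical stand-ins, not the authors' code or data.

```python
import numpy as np


def ablate_top_components(W: np.ndarray, x: np.ndarray, k: int) -> np.ndarray:
    """Remove from x its projection onto the top-k right singular vectors of W."""
    _, _, Vt = np.linalg.svd(W, full_matrices=False)
    V_k = Vt[:k]                      # shape (k, d), orthonormal rows
    return x - V_k.T @ (V_k @ x)      # project onto the orthogonal complement


rng = np.random.default_rng(0)
W = rng.normal(size=(10, 784))        # stand-in for a linear approximant's weights
x = rng.normal(size=784)              # stand-in for a flattened 28x28 image

x_part = ablate_top_components(W, x, k=4)    # partial ablation
x_full = ablate_top_components(W, x, k=10)   # entire row space removed

# Relative size of the perturbation applied to the input.
rel_change = np.linalg.norm(x - x_full) / np.linalg.norm(x)
print("relative input change:", rel_change)
print("output norm before:", np.linalg.norm(W @ x))
print("output norm after full ablation:", np.linalg.norm(W @ x_full))
```

Once all ten components are removed, the linear map's output on the input is zero up to floating-point error, even though only a 10-dimensional subspace of the 784-dimensional input has been changed.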

Theorems & Definitions (2)

  • Theorem 3.1: Master Theorem
  • proof