
Universal Neural Functionals

Allan Zhou, Chelsea Finn, James Harrison

TL;DR

UNFs address the challenge of exploiting permutation symmetries in neural weight spaces by automatically constructing maximal $\mathcal{S}$-equivariant linear layers for arbitrary collections of weight tensors and stacking them into deep networks. To realize all equivariant maps, they decompose the weight space into subspaces and use a basis of equivariant maps between each pair of subspaces; the construction extends to multi-channel weight features, and optional pooling yields invariant models. Empirically, UNFs improve learned-optimizer performance on MLP, CNN, RNN, and Transformer tasks and outperform prior weight-space models at predicting RNN generalization. The work culminates in an open-source library for building UNFs and demonstrates the practical value of symmetry-aware weight-space modeling for meta-learning and optimization.
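To make the idea of an equivariant layer with channels and pooling concrete, here is a minimal NumPy sketch for the simplest case: a single weight matrix whose rows and columns are permuted by independent permutations, as for a hidden-to-hidden MLP layer. For this symmetry the equivariant linear maps are spanned by four basis maps (identity, row mean, column mean, full mean, each broadcast back). The function names, channels-last layout, and mean pooling are illustrative choices made here, not the UNF library's API.

```python
import numpy as np

def equivariant_layer(x, theta, bias):
    """Linear layer on one weight matrix's features, equivariant to
    independent permutations of its rows and columns.

    x:     (n1, n2, c_in)   per-entry weight-space features, channels last
    theta: (4, c_in, c_out) one channel-mixing matrix per basis map
    bias:  (c_out,)         bias shared by every entry
    """
    # Four basis maps for R^{n1 x n2} -> R^{n1 x n2}: identity,
    # mean over rows, mean over columns, mean over everything,
    # each broadcast back to the full (n1, n2) grid.
    basis = [
        x,
        np.broadcast_to(x.mean(axis=0, keepdims=True), x.shape),
        np.broadcast_to(x.mean(axis=1, keepdims=True), x.shape),
        np.broadcast_to(x.mean(axis=(0, 1), keepdims=True), x.shape),
    ]
    # Mix channels separately for each basis map, then sum the results.
    out = sum(b @ t for b, t in zip(basis, theta))
    return out + bias

def invariant_pool(x):
    """Average over both permuted axes: a permutation-invariant (c,) summary."""
    return x.mean(axis=(0, 1))

# Equivariance check: permuting the input's rows and columns permutes
# the output the same way, and the pooled summary does not change.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 2))
theta = rng.normal(size=(4, 2, 4))
bias = rng.normal(size=4)
rows, cols = rng.permutation(5), rng.permutation(3)

y = equivariant_layer(x, theta, bias)
y_perm = equivariant_layer(x[rows][:, cols], theta, bias)
assert np.allclose(y_perm, y[rows][:, cols])
assert np.allclose(invariant_pool(y_perm), invariant_pool(y))
```

Interleaving such layers with elementwise nonlinearities (e.g. ReLU) keeps the whole stack equivariant, because pointwise functions commute with permutations of the entries.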

Abstract

A challenging problem in many modern machine learning tasks is to process weight-space features, i.e., to transform or extract information from the weights and gradients of a neural network. Recent works have developed promising weight-space models that are equivariant to the permutation symmetries of simple feedforward networks. However, they are not applicable to general architectures, since the permutation symmetries of a weight space can be complicated by recurrence or residual connections. This work proposes an algorithm that automatically constructs permutation equivariant models, which we refer to as universal neural functionals (UNFs), for any weight space. Among other applications, we demonstrate how UNFs can be substituted into existing learned optimizer designs, and find promising improvements over prior methods when optimizing small image classifiers and language models. Our results suggest that learned optimizers can benefit from considering the (symmetry) structure of the weight space they optimize. We open-source our library for constructing UNFs at https://github.com/AllanYangZhou/universal_neural_functional.

Paper Structure (18 sections, 2 theorems, 30 equations, 3 figures, 1 table, 1 algorithm)

Key Result

Theorem 3.1

Let $\{\mathcal{B}^{\ell m}\}$ be bases for each $\mathbb{L}_{\mathcal{S}}\left( \mathcal{W}^{(m)}, \mathcal{W}^{(\ell)} \right)$. Then the union of these bases, with each map extended to all of $\mathcal{W}$ (reading from $\mathcal{W}^{(m)}$ and writing to $\mathcal{W}^{(\ell)}$), is a basis for the linear equivariant maps on $\mathcal{W}$. That is, $\bigcup_{\ell, m} \mathcal{B}^{\ell m}$ is a basis for $\mathbb{L}_{\mathcal{S}}\left( \mathcal{W}, \mathcal{W} \right)$.
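One consequence of Theorem 3.1 is that the number of basis maps in a maximal equivariant linear layer on $\mathcal{W}$ is the sum of $\dim \mathbb{L}_{\mathcal{S}}(\mathcal{W}^{(m)}, \mathcal{W}^{(\ell)})$ over all pairs $(\ell, m)$. The sketch below counts these dimensions by standard orbit counting rather than by the paper's construction. It assumes every tensor axis is permuted by a named symmetric group, that each group size $n_g$ is at least the number of axes sharing that group, and it ignores biases and unpermuted input/output axes; under those assumptions each block contributes a product of Bell numbers. The axis labels and the two-layer RNN spec are hypothetical.

```python
from collections import Counter
from math import prod

def bell(k):
    """Bell number B_k (number of set partitions of k items), via the Bell triangle."""
    row = [1]
    for _ in range(k):
        new = [row[-1]]
        for v in row:
            new.append(new[-1] + v)
        row = new
    return row[0]

def block_dim(out_axes, in_axes):
    """dim L_S(W^(m), W^(l)) for tensors whose axes are labelled by the
    symmetric group permuting them. Each group g touching k_g axes (input
    and output together) contributes B_{k_g}, assuming n_g >= k_g."""
    counts = Counter(out_axes) + Counter(in_axes)
    return prod(bell(k) for k in counts.values())

def total_dim(spec):
    """Theorem 3.1: the union of per-block bases is a basis for L_S(W, W),
    so the dimensions of all (l, m) blocks add up."""
    return sum(block_dim(out, inp) for out in spec for inp in spec)

# One matrix with independently permuted rows and columns: 2 * 2 = 4 maps.
print(block_dim(("a", "b"), ("a", "b")))  # 4
# One square matrix whose rows and columns share a permutation: B_4 = 15.
print(block_dim(("a", "a"), ("a", "a")))  # 15

# Hypothetical hidden weights of a two-layer RNN (cf. Figure 1):
# recurrent h1->h1, feedforward h1->h2, recurrent h2->h2.
rnn_spec = [("h1", "h1"), ("h2", "h1"), ("h2", "h2")]
print(total_dim(rnn_spec))  # 62
```

The shared-permutation case is why recurrent weights admit many more equivariant maps than a feedforward matrix of the same shape: tying the row and column permutations merges two independent index sets into one.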

Figures (3)

  • Figure 1: Illustration of the permutation symmetries in the weight space of a recurrent neural network (Example 2.2). Left: Each layer contains feedforward (ff) weights mapping between different layers' activations, and recurrent (rec) weights transforming activations over time. We can permute the hidden activations as illustrated without changing the final outputs $h^L_t$. Right: Permuting the hidden activations induces a permutation on the weights. Here, the rows and columns of the feedforward weights are permuted by $(\sigma_{\ell+1}, \sigma_{\ell})$, while the recurrent weights are permuted by $(\sigma_\ell, \sigma_\ell)$. Our algorithm automatically constructs permutation equivariant models for any collection of weight tensors given a description of its symmetries (see Appendix).
  • Figure 2: Training loss (negative log-likelihood) curves for different tasks and architectures using meta-learned optimizers. We implement learned optimizers with either universal neural functionals (UNFs), NFNs (zhou2023permutation), or Deep Sets (zaheer2017deep). Deep Sets are the current standard choice for implementing learned optimizers. Note that NFN is identical to UNF in the MLP case, differs in the CNN case, and is not applicable to RNNs or Transformers. All loss curves are smoothed and averaged over $5$ random initializations ($3$ for the Transformer), with shaded regions showing standard error.
  • Figure 3: Number of parameters used by $f(\cdot)$ in each learned optimizer, for each task. Note that NFN and UNF are identical for the MLP task. This count excludes the other meta-learned scalars in the learned-optimizer update rule, namely $\alpha$, $\gamma_0$, and $\beta$.
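The symmetry described in Figure 1 can be checked numerically. The NumPy sketch below uses a hypothetical single-layer tanh RNN with a linear readout (sizes and names are illustrative), permutes its hidden units by $\sigma$, applies the induced permutation to the weights, and confirms that the outputs are unchanged. Because the input layer is not permuted here, the feedforward input weights only have their rows permuted, while the recurrent weights are permuted by $(\sigma, \sigma)$ as in the figure.

```python
import numpy as np

def rnn_forward(xs, w_ff, w_rec, b, w_out):
    """Single-layer tanh RNN: h_t = tanh(w_ff x_t + w_rec h_{t-1} + b), y_t = w_out h_t."""
    h = np.zeros(w_rec.shape[0])
    ys = []
    for x in xs:
        h = np.tanh(w_ff @ x + w_rec @ h + b)
        ys.append(w_out @ h)
    return np.stack(ys)

rng = np.random.default_rng(0)
n_in, n_hidden, n_out, T = 3, 6, 2, 5
xs = rng.normal(size=(T, n_in))
w_ff = rng.normal(size=(n_hidden, n_in))
w_rec = rng.normal(size=(n_hidden, n_hidden)) / np.sqrt(n_hidden)
b = rng.normal(size=n_hidden)
w_out = rng.normal(size=(n_out, n_hidden))

# Permute the hidden units and apply the induced permutation to the weights.
sigma = rng.permutation(n_hidden)
w_ff_p = w_ff[sigma]              # rows permuted by sigma
w_rec_p = w_rec[sigma][:, sigma]  # rows and columns permuted by (sigma, sigma)
b_p = b[sigma]                    # bias follows the hidden units
w_out_p = w_out[:, sigma]         # readout columns permuted by sigma

y = rnn_forward(xs, w_ff, w_rec, b, w_out)
y_p = rnn_forward(xs, w_ff_p, w_rec_p, b_p, w_out_p)
assert np.allclose(y, y_p)  # the network's function is unchanged
```

A weight-space model that ignores this symmetry must learn to treat `(w_ff, w_rec, b, w_out)` and its permuted copy identically from data; an equivariant model such as a UNF gets this for free.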

Theorems & Definitions (9)

  • Example 2.1: Multilayer perceptron
  • Example 2.2: Recurrent neural network
  • Example 2.3: Convolutional neural network
  • Theorem 3.1 (navon2023equivariant)
  • Definition 1
  • Example 3.1: $\mathcal{W}^{(m)}=\mathcal{W}^{(\ell)}=\mathbb{R}^{n_1 \times n_2}$
  • Example 3.2: $\mathcal{W}^{(m)}=\mathcal{W}^{(\ell)}=\mathbb{R}^{n_1 \times n_1}$
  • Example 3.3: $\mathcal{W}^{(m)}=\mathcal{W}^{(\ell)}=\mathbb{R}^{n_1 \times n_2}$
  • Theorem 3.2