GradMetaNet: An Equivariant Architecture for Learning on Gradients

Yoav Gelberg, Yam Eitan, Aviv Navon, Aviv Shamsian, Theo Putterman, Michael Bronstein, Haggai Maron

TL;DR

GradMetaNet is an equivariant architecture for learning on neural network gradients. It preserves neuron permutation symmetries and processes sets of per-datapoint gradients through a rank-1 gradient decomposition. The paper provides universality guarantees for continuous, gradient-based functions and demonstrates practical gains in curvature estimation, learned optimization, and INR editing on MLPs and transformers. GradMetaNet outperforms weight-space baselines across several tasks and scales to larger models. Because it operates on sets of gradients rather than on a single or averaged gradient, it captures local geometry that those representations discard, which makes it applicable to optimization, pruning, and editing.

Abstract

Gradients of neural networks encode valuable information for optimization, editing, and analysis of models. Therefore, practitioners often treat gradients as inputs to task-specific algorithms, e.g. for pruning or optimization. Recent works explore learning algorithms that operate directly on gradients but use architectures that are not specifically designed for gradient processing, limiting their applicability. In this paper, we present a principled approach for designing architectures that process gradients. Our approach is guided by three principles: (1) equivariant design that preserves neuron permutation symmetries, (2) processing sets of gradients across multiple data points to capture curvature information, and (3) efficient gradient representation through rank-1 decomposition. Based on these principles, we introduce GradMetaNet, a novel architecture for learning on gradients, constructed from simple equivariant blocks. We prove universality results for GradMetaNet, and show that previous approaches cannot approximate natural gradient-based functions that GradMetaNet can. We then demonstrate GradMetaNet's effectiveness on a diverse set of gradient-based tasks on MLPs and transformers, such as learned optimization, INR editing, and estimating loss landscape curvature.
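Principle (3) can be illustrated on a single linear layer. The sketch below is not taken from the paper: it assumes a hypothetical layer `W` with a squared-error loss, for which backpropagation gives the per-datapoint gradient as the outer product of the backpropagated signal and the layer input. All names and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, batch = 4, 3, 5            # hypothetical sizes

W = rng.normal(size=(d_out, d_in))      # weights of one linear layer
X = rng.normal(size=(batch, d_in))      # one input per row
Y = rng.normal(size=(batch, d_out))     # one target per row

# For the per-datapoint loss 0.5 * ||W x - y||^2, backprop gives
#   dL/dW = delta x^T   with   delta = W x - y.
A = X                                   # layer inputs, shape (batch, d_in)
Delta = X @ W.T - Y                     # backpropagated signals, shape (batch, d_out)

# Weight-shaped per-datapoint gradients: batch * d_out * d_in numbers.
full_grads = np.einsum("bo,bi->boi", Delta, A)

# Decomposed representation: only batch * (d_out + d_in) numbers.
decomposed_size = Delta.size + A.size

# Each per-datapoint gradient is an outer product, so its rank is 1 here.
ranks = [int(np.linalg.matrix_rank(g)) for g in full_grads]

# The gradient of the batch-averaged loss is the mean of the outer products.
mean_grad = Delta.T @ A / batch

print(full_grads.size, decomposed_size)  # 60 35
print(ranks)
```

Storing the two factors instead of the full matrix reduces the per-datapoint cost from `d_out * d_in` to `d_out + d_in` values, which is the efficiency argument behind the decomposed representation.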

Paper Structure

This paper contains 76 sections, 19 theorems, 107 equations, 9 figures, 12 tables.

Key Result

Proposition 6.1

Let $\{\nabla_{({\bm{x}}, {\bm{y}})}\}_{({\bm{x}}, {\bm{y}}) \in {{\mathcal{B}}}}$ be gradients computed on a set of datapoints ${{\mathcal{B}}} \subseteq {\mathcal{D}}$. There exist functions, such as natural gradient approximations or pruning saliency scores, that cannot be reconstructed from t
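The statement is truncated in this copy, so the following is only a sketch of the related point made in the abstract: a set of gradients carries curvature information that an averaged gradient does not. The two gradient sets are hypothetical toy values, and the empirical Fisher matrix $\frac{1}{n}\sum_i g_i g_i^\top$ is used here as an assumed stand-in for the curvature quantities the paper discusses.

```python
import numpy as np

# Two hypothetical sets of per-datapoint gradients (one gradient per row)
# for a model with two parameters.
grads_a = np.array([[1.0, 0.0], [-1.0, 0.0]])
grads_b = np.array([[0.0, 1.0], [0.0, -1.0]])

def mean_gradient(grads):
    """Gradient of the batch-averaged loss."""
    return grads.mean(axis=0)

def empirical_fisher(grads):
    """(1/n) * sum_i g_i g_i^T, computed from the whole set of gradients."""
    return grads.T @ grads / grads.shape[0]

mean_a, mean_b = mean_gradient(grads_a), mean_gradient(grads_b)
fisher_a, fisher_b = empirical_fisher(grads_a), empirical_fisher(grads_b)

# Both sets have the same (zero) average gradient ...
print(np.array_equal(mean_a, mean_b))      # True
# ... but their empirical Fisher matrices differ.
print(np.array_equal(fisher_a, fisher_b))  # False
```

Any function that sees only the averaged gradient must return the same value for both sets, so it cannot recover a quantity, such as the empirical Fisher matrix, that differs between them.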

Figures (9)

  • Figure 1: We propose GradMetaNet, a novel architecture that processes sets of gradients and can learn to compute gradient adaptations, parameter edits, or scalar values such as curvature information or influence functions.
  • Figure 2: Gradient information on a batch of datapoints in different tensor representations. The panels show three representations: a stack of the weight-shaped gradients, one for each datapoint; a stack of the rank-1 gradient decompositions; and the gradient of the average loss on the batch. All of these tensors are naturally computed when backpropagating the loss on the batch.
  • Figure 3: Fisher information as a second-order approximation to the loss.
  • Figure 4: The action of $G = S_{d_1} \times \cdots \times S_{d_{L-1}}$ on parameter space performs simultaneous permutation of rows and columns of consecutive weight matrices. In contrast, $G$'s action on the decomposed gradient space permutes the neuron space of each hidden layer independently.
  • Figure 5: GradMetaNet pipeline: gradients are decomposed into rank-1 factors and positional encoding is applied. The input is then transformed by a stack of ${L_{{\boldsymbol{\Gamma}}_b}}$ equivariant interactions-across-sets layers. ${L_{\mathrm{Pool}}}$ pools these representations into ${\boldsymbol{\Gamma}}[f]$, removing the batch dimension. Then a stack of ${L_{{\boldsymbol{\Gamma}}}}$ layers updates this representation, and ${L_{\mathrm{Prod}}}$ maps the result back to ${\boldsymbol{\Theta}}$.
  • ...and 4 more figures
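
The symmetry described in the Figure 4 caption can be checked numerically. The sketch below is an illustration under assumptions, not the paper's code: it uses a hypothetical two-layer tanh MLP with a squared-error loss, and the dictionary keys for the rank-1 factors are made-up names.

```python
import numpy as np

rng = np.random.default_rng(0)
d0, d1, d2 = 3, 4, 2                     # hypothetical layer widths
W1 = rng.normal(size=(d1, d0))
W2 = rng.normal(size=(d2, d1))
x = rng.normal(size=d0)
y = rng.normal(size=d2)

def forward_backward(W1, W2, x, y):
    """Output, rank-1 factors, and weight gradients for 0.5 * ||f(x) - y||^2."""
    h = np.tanh(W1 @ x)                        # hidden activations
    out = W2 @ h
    delta2 = out - y                           # signal at the output layer
    delta1 = (W2.T @ delta2) * (1.0 - h**2)    # signal at the hidden layer
    grads = (np.outer(delta1, x), np.outer(delta2, h))
    factors = {"a0": x, "d1": delta1, "a1": h, "d2": delta2}
    return out, factors, grads

perm = np.array([2, 0, 3, 1])            # a permutation of the hidden neurons
P = np.eye(d1)[perm]                     # P @ v == v[perm]

# Action on parameter space: rows of W1 and columns of W2 move together.
W1p, W2p = P @ W1, W2 @ P.T

out, f, g = forward_backward(W1, W2, x, y)
outp, fp, gp = forward_backward(W1p, W2p, x, y)

print(np.allclose(out, outp))                   # the network function is unchanged
print(np.allclose(gp[0], P @ g[0]))             # gradient of W1: rows permuted
print(np.allclose(gp[1], g[1] @ P.T))           # gradient of W2: columns permuted
print(np.allclose(fp["d1"], f["d1"][perm]))     # hidden-layer factors: entries permuted
```

In the weight-shaped view one permutation touches two consecutive matrices at once, whereas in the decomposed view it only reorders the entries of the hidden layer's own factors, leaving the input and output factors untouched.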

Theorems & Definitions (37)

  • Proposition 6.1
  • Theorem 6.2
  • Corollary 6.3
  • Definition E.1
  • Proposition E.2
  • proof
  • Definition E.3: Natural gradient map
  • Definition E.4: OBD/OBS pruning saliency maps
  • Proposition E.5
  • proof
  • ...and 27 more