
Equivariant Deep Weight Space Alignment

Aviv Navon, Aviv Shamsian, Ethan Fetaya, Gal Chechik, Nadav Dym, Haggai Maron

TL;DR

This work proves that weight alignment adheres to two fundamental symmetries, proposes a deep architecture that respects these symmetries, and provides a theoretical analysis of the approach. Experimental results indicate that a feed-forward pass with Deep-Align produces alignments that are better than or equivalent to those produced by current optimization algorithms.

Abstract

Permutation symmetries of deep networks make basic operations like model merging and similarity estimation challenging. In many cases, aligning the weights of the networks, i.e., finding optimal permutations between their weights, is necessary. Unfortunately, weight alignment is an NP-hard problem. Prior research has mainly focused on solving relaxed versions of the alignment problem, leading to either time-consuming methods or sub-optimal solutions. To accelerate the alignment process and improve its quality, we propose a novel framework aimed at learning to solve the weight alignment problem, which we name Deep-Align. To that end, we first prove that weight alignment adheres to two fundamental symmetries and then propose a deep architecture that respects these symmetries. Notably, our framework does not require any labeled data. We provide a theoretical analysis of our approach and evaluate Deep-Align on several types of network architectures and learning setups. Our experimental results indicate that a feed-forward pass with Deep-Align produces alignments that are better than or equivalent to those produced by current optimization algorithms. Additionally, our alignments can be used as an effective initialization for other methods, leading to improved solutions with a significant speedup in convergence.
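To make "finding optimal permutations between their weights" concrete, here is a minimal sketch, assuming a one-hidden-layer MLP stored as (W1, b1, W2, b2) and the common weight-matching objective (squared Euclidean distance between weights). Neither the parameterization nor the objective is taken from the paper, and the helper names are hypothetical. The search enumerates all n! hidden-unit permutations, so it is only feasible for tiny n.

```python
import itertools
import numpy as np

def permute_mlp(params, perm):
    """Reorder the hidden units of a one-hidden-layer MLP.

    params = (W1, b1, W2, b2) with W1: (n, d_in), b1: (n,), W2: (d_out, n).
    New hidden unit i is old unit perm[i], i.e. W1 -> P W1, b1 -> P b1, W2 -> W2 P^T.
    The network computes the same function after this reordering.
    """
    W1, b1, W2, b2 = params
    idx = list(perm)
    return W1[idx], b1[idx], W2[:, idx], b2

def weight_distance(p, q):
    """Squared Euclidean distance between two parameter tuples."""
    return sum(float(np.sum((a - b) ** 2)) for a, b in zip(p, q))

def align_brute_force(v, v_prime):
    """Return the hidden permutation that brings v_prime closest to v (checks all n!)."""
    n = v[0].shape[0]
    return min(itertools.permutations(range(n)),
               key=lambda perm: weight_distance(v, permute_mlp(v_prime, perm)))

rng = np.random.default_rng(0)
n, d_in, d_out = 5, 3, 2
v = (rng.normal(size=(n, d_in)), rng.normal(size=n),
     rng.normal(size=(d_out, n)), rng.normal(size=d_out))
hidden_perm = (2, 0, 4, 1, 3)
v_prime = permute_mlp(v, hidden_perm)   # same function, shuffled hidden units

best = align_brute_force(v, v_prime)    # the inverse of hidden_perm
print(best, weight_distance(v, permute_mlp(v_prime, best)))
```

For a single hidden layer this objective is in fact a linear assignment problem that the Hungarian algorithm solves in polynomial time; the hardness mentioned above arises when the permutations of several layers must be chosen jointly.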

Paper Structure

This paper contains 19 sections, 8 theorems, 23 equations, 11 figures, and 7 tables.

Key Result

Proposition 3.1

The map $\mathcal{G}$ is $H$-equivariant, namely, for all $(v,v')\in \mathcal{V}^2_{\text{unique}}$ and $(g,g')\in H$,

$$\mathcal{G}(g_{\#} v,g_{\#}' v')=g\cdot \mathcal{G}(v,v')\cdot g'^T.$$
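The equivariance identity $\mathcal{G}(g_{\#} v,g_{\#}' v')=g\cdot \mathcal{G}(v,v')\cdot g'^T$ (the relation depicted in Figure 1) can be checked numerically. The sketch below assumes a one-hidden-layer MLP, so each of $g$ and $g'$ is a single permutation matrix acting as $g_{\#}(W_1,b_1,W_2,b_2)=(gW_1, gb_1, W_2g^T, b_2)$. It uses a brute-force weight-matching solver as a hypothetical stand-in for $\mathcal{G}$ (not Deep-Align itself), and random Gaussian weights play the role of inputs with a unique optimal alignment.

```python
import itertools
import numpy as np

def act(P, v):
    """Action of a hidden-unit permutation matrix P on v = (W1, b1, W2, b2)."""
    W1, b1, W2, b2 = v
    return P @ W1, P @ b1, W2 @ P.T, b2

def dist(v, u):
    """Squared Euclidean distance between two parameter tuples."""
    return sum(float(np.sum((a - b) ** 2)) for a, b in zip(v, u))

def G(v, v_prime):
    """Hypothetical alignment map: permutation matrix P minimizing ||v - P_# v'||^2."""
    n = v[0].shape[0]
    eye = np.eye(n)
    candidates = (eye[list(p)] for p in itertools.permutations(range(n)))
    return min(candidates, key=lambda P: dist(v, act(P, v_prime)))

rng = np.random.default_rng(1)
n = 4

def random_mlp():
    return (rng.normal(size=(n, 3)), rng.normal(size=n),
            rng.normal(size=(2, n)), rng.normal(size=2))

v, v_prime = random_mlp(), random_mlp()
g = np.eye(n)[rng.permutation(n)]        # reorders the hidden units of v
g_prime = np.eye(n)[rng.permutation(n)]  # reorders the hidden units of v'

lhs = G(act(g, v), act(g_prime, v_prime))
rhs = g @ G(v, v_prime) @ g_prime.T
print(np.allclose(lhs, rhs))
```

The identity holds here because the distance is unchanged when both arguments are permuted by the same matrix, so a minimizer $P$ for $(v,v')$ corresponds to the minimizer $gPg'^T$ for the reordered pair.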

Figures (11)

  • Figure 1: The equivariance structure of the alignment problem. The function $\mathcal{G}$ takes as input two weight-space vectors $v,v'$ and outputs a sequence of permutation matrices that aligns them, denoted $\mathcal{G}(v,v')$. If the inputs are reordered using $(g,g')$, where $g=(P_1,P_2)$ and $g'=(P'_1,P'_2)$, the optimal alignment transforms accordingly: $\mathcal{G}(g_{\#} v,g_{\#}' v')=g\cdot \mathcal{G}(v,v')\cdot g'^T$.
  • Figure 2: Our architecture is a composition of four blocks. The first block, $F_{DWS}$, generates weight-space embeddings for both inputs. The second block, $F_{\mathcal{V}\rightarrow \mathcal{A}}$, maps these embeddings to the activation spaces. The third block, $F_{Prod}$, generates square matrices by taking outer products between the activation vectors of one network and those of the other. Finally, the fourth block, $F_{Proj}$, projects these square matrices onto the (convex hull of the) permutation matrices.
  • Figure 3: Merging image classifiers: the plots show the values of the loss function used to train the input networks, evaluated along the line segment connecting $v$ and $g_{\#} v'$, where $g$ is the output of each method. Values are averaged over all test images, networks, and 3 random seeds.
  • Figure 4: Sample size: The test barrier for aligning CIFAR10 CNN classifiers with a varying number of training examples.
  • Figure 5: Runtime comparison: Deep-Align is significantly more efficient at inference than the baseline methods.
  • ...and 6 more figures
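The last two blocks of Figure 2 can be sketched as follows, under assumptions not stated in this summary: each network's layer is represented by an (n, d) matrix of per-neuron activation-space features, $F_{Prod}$ sums outer products over the d feature channels, and $F_{Proj}$ is approximated with Sinkhorn normalization, a standard way to map a score matrix into the convex hull of permutation matrices (the doubly stochastic matrices). The function names and the temperature `tau` are hypothetical, and this is not the authors' implementation.

```python
import numpy as np

def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along an axis, keeping dimensions."""
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def f_prod(a, a_prime):
    """Sum over feature channels c of the outer products a[:, c] a_prime[:, c]^T."""
    # Summing the per-channel outer products equals a single matrix product.
    return a @ a_prime.T

def sinkhorn(scores, tau=0.05, n_iters=200):
    """Map a square score matrix to an (approximately) doubly stochastic matrix."""
    log_p = scores / tau
    for _ in range(n_iters):
        log_p = log_p - logsumexp(log_p, axis=1)  # normalize rows
        log_p = log_p - logsumexp(log_p, axis=0)  # normalize columns
    return np.exp(log_p)

rng = np.random.default_rng(0)
n, d = 6, 16
a = rng.normal(size=(n, d))
a /= np.linalg.norm(a, axis=1, keepdims=True)       # unit-norm neuron features
perm = rng.permutation(n)
a_prime = a[perm] + 0.01 * rng.normal(size=(n, d))  # shuffled, slightly noisy copy

p = sinkhorn(f_prod(a, a_prime))
# Row i of p should peak at the neuron of a_prime that copies a[i].
print(np.argmax(p, axis=1), np.argsort(perm))
```

Because row and column normalizations commute with permuting rows or columns, this projection is itself permutation equivariant, consistent with the symmetries the architecture is designed to respect.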

Theorems & Definitions (14)

  • Proposition 3.1
  • Proposition 3.2
  • Proposition 4.1
  • Proposition 5.1
  • Proposition 5.2: Deep-Align is exact for perfect alignments
  • Proof of Proposition \ref{prop:equi}
  • Proof of Proposition \ref{prop:inv}
  • Proof of generalization to other objectives
  • Proposition 3.1: Full formulation of Proposition \ref{prop:activation_match}
  • Proof of Proposition \ref{prop:activation_match2}
  • ...and 4 more