Simultaneous linear connectivity of neural networks modulo permutation

Ekansh Sharma, Devin Kwok, Tom Denton, Daniel M. Roy, David Rolnick, Gintare Karolina Dziugaite

TL;DR

This work addresses permutation symmetries in neural networks, which make loss landscapes non-convex, by distinguishing three notions of linear connectivity (LC) modulo permutation. It refines prior results to show that existing evidence mainly supports weak LC, introduces simultaneous weak LC along SGD trajectories and iterative magnitude pruning (IMP) sequences, and presents initial evidence that strong LC may emerge as networks become very wide. The approach uses weight matching and activation matching to align networks, investigates how a single permutation can connect multiple related networks, and demonstrates that permuted masks from IMP can be transported across independently trained models. The findings suggest that, under certain conditions, the loss landscape may be effectively convex once permutations are accounted for, with practical implications for federated learning and model fusion. They also identify limitations of current alignment algorithms, which so far restrict evidence of strong LC to very wide networks.
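
As a rough illustration of the weight matching idea mentioned above, the sketch below searches for the hidden-unit permutation of one network that minimizes its squared distance to another. It is a toy under stated assumptions, not the authors' implementation: the one-hidden-layer parameter layout, the helper names (`permute`, `sq_dist`, `weight_matching`), and the exhaustive search are hypothetical choices made for readability. Practical implementations typically solve a linear assignment problem per layer rather than enumerating permutations.

```python
from itertools import permutations

# Assumed layout (hypothetical): params = (w1, b1, w2), one entry per hidden unit.

def permute(params, perm):
    """Reorder hidden units; unit perm[i] of the input becomes unit i."""
    return tuple([vec[i] for i in perm] for vec in params)

def sq_dist(pa, pb):
    """Squared Euclidean distance between two parameter tuples."""
    return sum((a - b) ** 2 for va, vb in zip(pa, pb) for a, b in zip(va, vb))

def weight_matching(pa, pb):
    """Brute-force search for the permutation of pb closest to pa.

    Exhaustive search stands in for a linear assignment solver here,
    so this is only feasible for a handful of hidden units.
    """
    n = len(pa[0])
    return min(permutations(range(n)),
               key=lambda perm: sq_dist(pa, permute(pb, perm)))

# Network B is network A with its hidden units shuffled.
theta_a = ([1.0, -1.0, 0.5], [0.0, 0.1, -0.2], [1.0, 2.0, 3.0])
theta_b = permute(theta_a, (2, 0, 1))

best = weight_matching(theta_a, theta_b)
aligned_b = permute(theta_b, best)
residual = sq_dist(theta_a, aligned_b)
print(best, residual)
```

Because `theta_b` is an exact shuffle of `theta_a`, the search recovers the inverse shuffle and the residual distance is zero. For independently trained networks the residual would generally be nonzero, and the quality of the alignment would be judged by the loss barrier rather than by distance alone.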

Abstract

Neural networks typically exhibit permutation symmetries which contribute to the non-convexity of the networks' loss landscapes, since linearly interpolating between two permuted versions of a trained network tends to encounter a high loss barrier. Recent work has argued that permutation symmetries are the only sources of non-convexity, meaning there are essentially no such barriers between trained networks if they are permuted appropriately. In this work, we refine these arguments into three distinct claims of increasing strength. We show that existing evidence only supports "weak linear connectivity": for each pair of networks belonging to a set of SGD solutions, there exist (multiple) permutations that linearly connect it with the other networks. In contrast, the claim of "strong linear connectivity" (that for each network, there exists one permutation that simultaneously connects it with the other networks) is both intuitively and practically more desirable. This stronger claim would imply that the loss landscape is convex after accounting for permutation, and enable linear interpolation between three or more independently trained models without increased loss. In this work, we introduce an intermediate claim: for certain sequences of networks, there exists one permutation that simultaneously aligns matching pairs of networks from these sequences. Specifically, we discover that a single permutation aligns sequences of iteratively trained as well as iteratively pruned networks, meaning that two networks exhibit low loss barriers at each step of their optimization and sparsification trajectories, respectively. Finally, we provide the first evidence that strong linear connectivity may be possible under certain conditions, by showing that barriers decrease with increasing network width when interpolating among three networks.
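
The claim that interpolating between two permuted copies of a network meets a loss barrier can be checked on a toy example. The sketch below is an illustration only, assuming a one-hidden-layer ReLU network with scalar input, a mean-squared-error loss, and a barrier defined as the largest gap between the loss on the interpolation path and the interpolated endpoint losses. All function names and the data are hypothetical.

```python
def forward(params, x):
    """One-hidden-layer ReLU network with scalar input and output."""
    w1, b1, w2 = params  # one entry per hidden unit
    return sum(v * max(0.0, w * x + b) for w, b, v in zip(w1, b1, w2))

def permute(params, perm):
    """Reorder hidden units; the function computed is unchanged."""
    return tuple([vec[i] for i in perm] for vec in params)

def interpolate(pa, pb, alpha):
    """Return alpha * pa + (1 - alpha) * pb, coordinate by coordinate."""
    return tuple([alpha * a + (1 - alpha) * b for a, b in zip(va, vb)]
                 for va, vb in zip(pa, pb))

def mse(params, data):
    return sum((forward(params, x) - y) ** 2 for x, y in data) / len(data)

def barrier(pa, pb, data, steps=21):
    """Largest excess of the path loss over the interpolated endpoint losses."""
    la, lb = mse(pa, data), mse(pb, data)
    gaps = []
    for k in range(steps):
        alpha = k / (steps - 1)
        path_loss = mse(interpolate(pa, pb, alpha), data)
        gaps.append(path_loss - (alpha * la + (1 - alpha) * lb))
    return max(gaps)

# theta_a computes |x| exactly: relu(x) + relu(-x).
theta_a = ([1.0, -1.0], [0.0, 0.0], [1.0, 1.0])
theta_b = permute(theta_a, [1, 0])  # same function, hidden units swapped
data = [(x, abs(x)) for x in (-2.0, -1.0, 0.0, 1.0, 2.0)]

barrier_before = barrier(theta_a, theta_b, data)
barrier_after = barrier(theta_a, permute(theta_b, [1, 0]), data)
print(barrier_before, barrier_after)
```

Both endpoints have zero loss, yet the midpoint of the naive path has first-layer weights of zero and therefore predicts zero everywhere, which produces the barrier. Undoing the swap before interpolating removes it.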

Paper Structure

This paper contains 54 sections, 2 theorems, 9 equations, 23 figures, and 1 algorithm.

Key Result

Corollary (Piecewise linear connectivity)

Let $\mathcal{F} \subseteq \mathbb{R}^k$ be a subset of neural networks that satisfies def:app-conjecture-entezari. Then, for all $\theta_1, \theta_2 \in \mathcal{F}$, there exists $\theta_3 \in \mathcal{F}$ such that $B(\theta_1, \theta_3)\approx 0$ and $B(\theta_2, \theta_3)\approx 0$.
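
For readers without the full paper at hand, the barrier $B$ in this statement can be read as follows. This is the definition commonly used in the linear mode connectivity literature, written with an assumed loss symbol $\mathcal{L}$; the paper's own definition may differ in details. The second display is a sketch of the path the corollary implies, with $\gamma$ a hypothetical name for it.

```latex
% Assumed (standard) loss barrier between two parameter vectors:
B(\theta_1, \theta_2)
  = \sup_{\alpha \in [0,1]}
    \Big[ \mathcal{L}\big(\alpha\theta_1 + (1-\alpha)\theta_2\big)
          - \big(\alpha\,\mathcal{L}(\theta_1) + (1-\alpha)\,\mathcal{L}(\theta_2)\big) \Big]

% Two-segment path from \theta_1 to \theta_2 through \theta_3:
\gamma(t) =
\begin{cases}
  (1 - 2t)\,\theta_1 + 2t\,\theta_3,       & t \in [0, \tfrac{1}{2}], \\
  (2 - 2t)\,\theta_3 + (2t - 1)\,\theta_2, & t \in [\tfrac{1}{2}, 1].
\end{cases}
```

Under this reading, $B(\theta_1, \theta_3)\approx 0$ and $B(\theta_2, \theta_3)\approx 0$ mean that each segment of $\gamma$ stays at low loss, so $\theta_1$ and $\theta_2$ are connected by a piecewise linear path even when the direct segment between them is not low-loss.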

Figures (23)

  • Figure 1: (Left) loss barrier ($y$-axis) between networks $A_t$ and permuted $B_t$ at various training times $t$ ($x$-axis). Permutation is computed at the end of training. The green line corresponds to applying permutations found via weight matching, blue corresponds to permutations found via activation matching, and gray corresponds to no permutation. (Right) loss barrier ($y$-axis) between $k$ "child networks" $A^k$ spawned from $A$ and permuted children $P_{\mathrm{end}}[B_t^k]$ of $B$ spawned at time $t$ ($x$-axis) from their respective "parent" networks (dashed brown line, $A_t^1 \leftrightarrow P_{\mathrm{end}}[B_t^1]$). $P_{\mathrm{end}}$ is computed by aligning parent networks $A$ and $B$ at the end of training (dotted grey line marks the barrier between the permuted parents, $A \leftrightarrow P_{\mathrm{end}}[B]$). These loss barriers are compared against the average loss barrier between independent child networks with the same parent (dot-dashed purple line, $A_t^1 \leftrightarrow A_t^2$). Each child is trained with a different minibatch order starting from the parent's weights at time $t$.
  • Figure 2: (Left) error barrier (y-axis) between permuted sparse IMP subnetworks derived from dense networks trained with different initializations and SGD noise. At each sparsity level (x-axis), pairs of sparse networks are aligned using either (1) a permutation $P^{(k)}$ computed directly on the sparse subnetworks at sparsity level $k$ (orange) or (2) a permutation computed from the corresponding dense networks (blue). (Right) test accuracy of sparse networks using transported masks, at increasing levels of sparsity (x-axis indicates fraction of remaining weights). We use weight matching to compute the permutation.
  • Figure 3: Test of strong linear connectivity, comparing barriers (y-axis) of networks aligned directly or relative to a reference network (colors) via activation and weight matching (line styles). (Left) VGG-16 models of increasing width (x-axis). (Right) ResNet-20 models of increasing width (x-axis).
  • Figure 4: The error barrier (y-axis) between networks $A_t$ and permuted $B_t$ at different training checkpoints $t$ (x-axis). The brown line corresponds to applying a permutation $P_{\mathrm{end}}$ found at the end of training, which successfully eliminates the error barrier between the networks at nearly every epoch. The orange line corresponds to a permutation $P_t$ computed and applied at time $t$.
  • Figure 5: (Left) effect of magnitude pruning on the performance of the weight matching algorithm. Loss barrier (y-axis) for networks after training which are aligned with a permutation found via weight matching. Permutations are computed on checkpoints at different training epochs (colors) which are first sparsified (x-axis) via random (dashed) or magnitude (solid) pruning. (Right) loss barriers at the end of training between a pair of networks under "bottom-up" partial alignment, which concatenates $P_t$ from input up to layer $k$ with $P_{\mathrm{end}}$ for the remaining layers. The rewind time $t$ is the x-axis, and the split point $k$ is indicated by line color. Barriers for $P_t$ computed at time $t$ (x-axis) and at the end of training are included as baselines (dotted lines).
  • ...and 18 more figures
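
The "transported masks" in Figure 2 can be pictured with a small sketch. It assumes a single weight matrix stored as a list of rows (one row per hidden unit), a binary mask of the same shape, and a permutation `perm` that reorders network B's units into network A's order. All names and values are hypothetical, and a full network would also need the matching column permutation applied to the next layer.

```python
def permute_units(rows, perm):
    """Reorder hidden units (rows); row perm[i] of the input becomes row i."""
    return [rows[i] for i in perm]

def invert(perm):
    """Inverse permutation, so that invert(perm)[perm[i]] == i."""
    inv = [0] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        inv[old_pos] = new_pos
    return inv

def apply_mask(weights, mask):
    """Zero out pruned weights elementwise."""
    return [[w * m for w, m in zip(w_row, m_row)]
            for w_row, m_row in zip(weights, mask)]

# Hypothetical values: B's first-layer weights, a mask found by pruning A,
# and a permutation that puts B's units into A's order.
weights_b = [[0.9, -0.1], [0.05, 0.7], [-0.6, 0.02]]
mask_a = [[1, 0], [1, 0], [0, 1]]
perm = [2, 0, 1]

# Route 1: move B into A's coordinates, then apply A's mask.
route_1 = apply_mask(permute_units(weights_b, perm), mask_a)

# Route 2: transport A's mask into B's coordinates, prune B in place,
# then permute the pruned network for comparison.
mask_b = permute_units(mask_a, invert(perm))
route_2 = permute_units(apply_mask(weights_b, mask_b), perm)

print(mask_b, route_1 == route_2)
```

The two routes agree, which is the sense in which a mask can be carried from one network to another through the permutation: pruning B with the transported mask is the same as aligning B to A and pruning with A's mask.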

Theorems & Definitions (12)

  • Definition: Weak linear connectivity modulo permutation
  • Definition: Strong linear connectivity modulo permutation
  • Definition: Simultaneous weak linear connectivity modulo permutation
  • Definition
  • Definition: WLC mod P
  • Definition: SLC mod P
  • Corollary: Piecewise linear connectivity
  • Proof
  • Proposition
  • Proof
  • ...and 2 more
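
The three connectivity notions named in this list can be paraphrased symbolically from the abstract. The block below is a hedged reading, not the paper's formal definitions: it assumes a set of trained networks $\theta_1, \dots, \theta_n$, two sequences $(A_t)$ and $(B_t)$ as in the figure captions, the barrier $B(\cdot,\cdot)$ from the Key Result, and the caption notation $P[\cdot]$ for applying a permutation. The precise tolerance behind "$\approx 0$" is left unspecified here.

```latex
% Weak: a (possibly different) permutation for every pair.
\forall\, i \neq j \;\; \exists\, P_{ij} : \quad
  B\big(\theta_i,\, P_{ij}[\theta_j]\big) \approx 0

% Strong: one permutation per network that works for all pairs at once.
\exists\, P_1, \dots, P_n \;\; \forall\, i, j : \quad
  B\big(P_i[\theta_i],\, P_j[\theta_j]\big) \approx 0

% Simultaneous weak: one permutation aligning matching pairs of two sequences.
\exists\, P \;\; \forall\, t : \quad
  B\big(A_t,\, P[B_t]\big) \approx 0
```

The difference lies in the order of quantifiers: weak connectivity lets the permutation depend on the pair, while strong connectivity fixes one permutation per network before any pair is chosen, which is what would allow interpolation among three or more networks.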