Linear Mode Connectivity in Sparse Neural Networks

Luke McDermott, Daniel Cummings

TL;DR

This work investigates sparse neural networks found by pruning with distilled (synthetic) data, exploring how dataset distillation interacts with pruning to yield stable subnetworks at initialization. It introduces distilled pruning, in which a sparsity mask is found by training on a synthetic dataset and rewinding to initialization, and shows that the resulting synthetic subnetworks can exhibit linear mode connectivity and robustness to SGD noise. Using Information-intensive Dataset Condensation (IDC), the authors show that synthetic masks can match the performance of traditional Iterative Magnitude Pruning (IMP) while using up to 150x fewer training points. The findings suggest that pruning driven by synthetic data can produce stable, high-quality subnetworks resembling lottery tickets, motivating further work on searching for stable subnetworks at initialization.
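Distilled pruning, as described above, is iterative magnitude pruning in which every round trains on the synthetic set, removes the smallest surviving weights, and rewinds the rest to their initial values. A minimal sketch follows, assuming a toy linear least-squares model in place of a CNN and a small random dataset in place of IDC-distilled images. `train`, `prune_by_magnitude`, and `imp_rewind_to_init` are hypothetical names, and the pruning fraction and number of rounds are illustrative, not the paper's settings.

```python
import random


def train(weights, mask, data, lr=0.1, epochs=300):
    """Hypothetical stand-in for training on the (distilled) dataset:
    full-batch gradient descent on a linear least-squares model.
    Pruned weights (mask == 0) are held at zero throughout."""
    w = [wi * mi for wi, mi in zip(weights, mask)]
    n = len(data)
    for _ in range(epochs):
        grad = [0.0] * len(w)
        for x, y in data:
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            for j, xj in enumerate(x):
                grad[j] += 2.0 * err * xj / n
        w = [(wi - lr * g) * mi for wi, g, mi in zip(w, grad, mask)]
    return w


def prune_by_magnitude(trained, mask, frac):
    """Remove the `frac` smallest-magnitude weights among those still alive."""
    alive = [j for j, m in enumerate(mask) if m]
    k = int(len(alive) * frac)
    drop = set(sorted(alive, key=lambda j: abs(trained[j]))[:k])
    return [0 if j in drop else m for j, m in enumerate(mask)]


def imp_rewind_to_init(init, data, rounds, frac=0.2, train_fn=train):
    """Iterative magnitude pruning with rewinding to initialization:
    every round retrains the surviving weights starting from `init`."""
    mask = [1] * len(init)
    for _ in range(rounds):
        trained = train_fn(init, mask, data)  # rewind: always start from init
        mask = prune_by_magnitude(trained, mask, frac)
    return mask  # the sparse subnetwork at initialization is init * mask


# Hypothetical small dataset: 50 points, 8 features, only features 0-2 matter.
rng = random.Random(0)
true_w = [3.0, -2.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0]
data = []
for _ in range(50):
    x = [rng.uniform(-1.0, 1.0) for _ in range(8)]
    data.append((x, sum(wi * xi for wi, xi in zip(true_w, x))))
init = [rng.gauss(0.0, 0.1) for _ in range(8)]

mask = imp_rewind_to_init(init, data, rounds=2, frac=0.5)
print(mask)
```

Because each round restarts from `init`, the returned mask defines a subnetwork at initialization rather than a lottery ticket rewound to an early training iterate. Passing the real training set as `data` instead gives ordinary IMP with rewinding to initialization, the baseline the paper compares against.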

Abstract

With rising interest in sparse neural networks, we study how neural network pruning with synthetic data leads to sparse networks with unique training properties. We find that distilled data, a synthetic summary of the real data, paired with Iterative Magnitude Pruning (IMP), unveils a new class of sparse networks that are more stable to SGD noise on the real data than either the dense model or subnetworks found with real data in IMP. That is, synthetically chosen subnetworks often train to the same minimum, or exhibit linear mode connectivity. We study this through linear interpolation, loss landscape visualizations, and measurement of the diagonal of the Hessian. While dataset distillation as a field is still young, we find that these properties lead to synthetic subnetworks matching the performance of traditional IMP with up to 150x fewer training points in settings where distilled data applies.
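Two of the probes named in the abstract can be stated concretely: the loss along the straight line between two trained solutions, and the diagonal of the Hessian. The sketch below assumes parameters are flat Python lists and defines the barrier as the maximum loss on the path minus the mean of the two endpoint losses, one common way to quantify linear mode connectivity; the paper's exact metric may differ. The Hessian diagonal is estimated with central finite differences, a swapped-in technique that is only practical for tiny models (real networks would use automatic differentiation or a stochastic estimator). `toy_loss` is a hypothetical loss with two separate minima.

```python
def interpolate(theta_a, theta_b, alpha):
    """Point on the straight line (1 - alpha) * theta_a + alpha * theta_b."""
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(theta_a, theta_b)]


def loss_barrier(loss_fn, theta_a, theta_b, steps=21):
    """Largest loss along the linear path minus the mean of the endpoint losses.
    A barrier near zero suggests the two solutions are linearly mode connected."""
    alphas = [i / (steps - 1) for i in range(steps)]
    losses = [loss_fn(interpolate(theta_a, theta_b, a)) for a in alphas]
    return max(losses) - 0.5 * (losses[0] + losses[-1])


def hessian_diagonal(loss_fn, theta, eps=1e-4):
    """Central finite-difference estimate of each d^2 L / d theta_i^2."""
    base = loss_fn(theta)
    diag = []
    for i in range(len(theta)):
        plus, minus = list(theta), list(theta)
        plus[i] += eps
        minus[i] -= eps
        diag.append((loss_fn(plus) - 2.0 * base + loss_fn(minus)) / eps**2)
    return diag


def toy_loss(theta):
    # Hypothetical loss with two separate minima at (1, 0) and (-1, 0).
    return (theta[0] ** 2 - 1.0) ** 2 + theta[1] ** 2


separate = loss_barrier(toy_loss, [1.0, 0.0], [-1.0, 0.0])  # different basins
same = loss_barrier(toy_loss, [1.0, 0.1], [1.0, -0.1])      # same basin
curv = hessian_diagonal(toy_loss, [1.0, 0.0])                # approximately [8, 2]
print(separate, same, curv)
```

A barrier near zero (as for `same`) is what training to the same minimum looks like; a large barrier (as for `separate`) means the two runs landed in different basins.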

Paper Structure (11 sections, 4 equations, 6 figures)

Figures (6)

  • Figure 1: Comparison of the stability of synthetic vs. IMP subnetworks at initialization on CIFAR-10. We show how the loss increases as the weights are interpolated between two trained models, for subnetworks of different sparsities. The left column shows subnetworks found via distilled data and the middle column shows subnetworks found with real data. The dark lines in the 3D plots represent the pruning iteration used for the combined plot; the dense model is iteration 0.
  • Figure 2: Comparison of the stability of synthetic vs. IMP subnetworks at initialization on ImageNet-10 with ResNet-10. An increase in loss along the interpolation path indicates instability, i.e., the trained networks landing in different minima.
  • Figure 3: Loss landscape visualization of the neighborhood defined by models trained on different seeds, for ConvNet-3 on CIFAR-10.
  • Figure 4: On the left, we draw a parallel between intra-model representation learning and the deep information bottleneck. On the right, we show an example of using dataset distillation to compress a training set to 1 image per class for classification tasks.
  • Figure 5: Performance of Distilled Pruning vs. Traditional IMP on ResNet-18 and CIFAR-10. The distilled dataset consisted of 10 images per class. Results are averaged across 4 seeds, with error bars shown. The plot on the right measures the number of data points used in training to find a sparsity mask at a given sparsity. Note that in IMP we are not finding "lottery tickets", since both methods rewind to initialization rather than to an early point in training.
  • ...and 1 more figure
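Figure 3 visualizes the loss in the neighborhood defined by models trained on different seeds. The caption does not specify the construction, so the sketch below assumes a common one: evaluate the loss on the plane through three trained parameter vectors, theta(a, b) = t1 + a (t2 - t1) + b (t3 - t1). The "seed" solutions and `toy_loss` are hypothetical stand-ins for trained ConvNet-3 weights and the CIFAR-10 loss.

```python
def plane_point(t1, t2, t3, a, b):
    """theta(a, b) = t1 + a * (t2 - t1) + b * (t3 - t1)."""
    return [x1 + a * (x2 - x1) + b * (x3 - x1) for x1, x2, x3 in zip(t1, t2, t3)]


def loss_grid(loss_fn, t1, t2, t3, lo=-0.5, hi=1.5, n=9):
    """Evaluate the loss on an n x n grid of plane coordinates (a, b).
    grid[row][col] holds the loss at (a = coords[col], b = coords[row])."""
    coords = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    grid = [[loss_fn(plane_point(t1, t2, t3, a, b)) for a in coords] for b in coords]
    return coords, grid


def toy_loss(theta):
    # Hypothetical loss with two separate minima at (1, 0) and (-1, 0).
    return (theta[0] ** 2 - 1.0) ** 2 + theta[1] ** 2


# Hypothetical "seeds": t1 and t3 share a basin, t2 sits in the other one.
t1, t2, t3 = [1.0, 0.0], [-1.0, 0.0], [1.0, 0.5]
coords, grid = loss_grid(toy_loss, t1, t2, t3)
for row in reversed(grid):  # print with b increasing upward
    print(" ".join(f"{v:6.2f}" for v in row))
```

The three trained models sit at (a, b) = (0, 0), (1, 0), and (0, 1); a ridge of high loss between them in the printed grid plays the role of the barriers seen in the interpolation plots.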