Finding Stable Subnetworks at Initialization with Dataset Distillation

Luke McDermott, Rahul Parhi

TL;DR

This work addresses finding stable subnetworks at initialization by leveraging dataset distillation to create a compact synthetic training set. The authors introduce Distilled Pruning, which uses distilled data in the inner loop of iterative magnitude pruning (IMP) to obtain stable subnetworks from unstable dense initializations, and show that these synthetic subnetworks can match the performance of traditional lottery tickets on CIFAR-10 with ResNet-18 while using 150x fewer training points. They further demonstrate that combining Distilled Pruning with IMP yields lottery tickets at high sparsities, including notable gains on ImageNet subsets, and provide supporting analyses (linear mode connectivity, loss-landscape visualizations, and Hessian diagnostics) indicating that distilled subnetworks exhibit greater stability and smoother loss surfaces than those found by conventional IMP. The findings suggest that distilled data can guide pruning dynamics from initialization, enabling efficient, high-sparsity lottery tickets and informing future data-centric pruning strategies. Practical impact includes potential reductions in training data needs and computational costs for discovering and training sparse subnetworks.

Abstract

Recent works have shown that Dataset Distillation, the process of summarizing the training data, can be leveraged to accelerate the training of deep learning models. However, its impact on training dynamics, particularly in neural network pruning, remains largely unexplored. In our work, we use distilled data in the inner loop of iterative magnitude pruning to produce sparse, trainable subnetworks at initialization, more commonly known as lottery tickets. While using 150x fewer training points, our algorithm matches the performance of traditional lottery ticket rewinding on ResNet-18 & CIFAR-10. Previous work highlights that lottery tickets can be found when the dense initialization is stable to SGD noise (i.e., training across different orderings of the data converges to the same minimum). We extend this discovery, demonstrating that stable subnetworks can exist even within an unstable dense initialization. In our linear mode connectivity studies, we find that pruning with distilled data discards parameters that contribute to the sharpness of the loss landscape. Lastly, we show that by first generating a stable sparsity mask at initialization, we can find lottery tickets at significantly higher sparsities than traditional iterative magnitude pruning.
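To make the pruning loop described in the abstract concrete, here is a minimal sketch of iterative magnitude pruning with rewinding to initialization, where the inner training loop sees only a small synthetic dataset. It is a toy under loud assumptions, not the authors' implementation: the model is a masked linear regressor rather than ResNet-18, the "distilled" points are hand-written rather than produced by a dataset distillation method, and the names `train`, `prune_smallest`, and `distilled_pruning` are hypothetical.

```python
import random

def train(init_weights, mask, data, lr=0.1, epochs=200):
    """Toy inner loop: fit a masked linear model y = w . x by per-sample SGD."""
    w = [wi * mi for wi, mi in zip(init_weights, mask)]
    for _ in range(epochs):
        for x, y in data:
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            # Gradient step on squared error; pruned entries stay at exactly zero.
            w = [(wi - lr * err * xi) * mi for wi, xi, mi in zip(w, x, mask)]
    return w

def prune_smallest(trained, mask, fraction):
    """Clear the mask for the smallest-magnitude `fraction` of surviving weights."""
    alive = sorted((abs(trained[i]), i) for i in range(len(mask)) if mask[i])
    new_mask = list(mask)
    for _, i in alive[:int(len(alive) * fraction)]:
        new_mask[i] = 0
    return new_mask

def distilled_pruning(init_weights, distilled_data, rounds, fraction):
    """IMP with rewinding to initialization, training only on the distilled data."""
    mask = [1] * len(init_weights)
    for _ in range(rounds):
        # Every round restarts from the same initialization (rewinding).
        trained = train(init_weights, mask, distilled_data)
        mask = prune_smallest(trained, mask, fraction)
    return mask

random.seed(0)
init = [random.uniform(-0.5, 0.5) for _ in range(8)]
# Assumed stand-in for a distilled set: three synthetic points in which
# only features 0 and 1 carry signal.
distilled = [
    ([1, 0, 0, 0, 0, 0, 0, 0], 2.0),
    ([0, 1, 0, 0, 0, 0, 0, 0], -3.0),
    ([1, 1, 0, 0, 0, 0, 0, 0], -1.0),
]
mask = distilled_pruning(init, distilled, rounds=2, fraction=0.5)
print(mask)
```

In this toy setting the two pruning rounds remove the six weights that never receive a gradient signal, leaving a mask at 75% sparsity. The returned mask would then be applied to the original initialization and the sparse network trained on real data, which is the step that decides whether the subnetwork counts as a lottery ticket.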

Paper Structure

This paper contains 11 sections, 1 equation, 7 figures.

Figures (7)

  • Figure 1: Distilled Pruning Algorithm Diagram
  • Figure 2: Performance of Distilled Pruning vs Traditional IMP on ResNet-18 & CIFAR-10. The distilled dataset consisted of 10 images per class. Results are averaged across 4 seeds, with error bars shown. The plot on the right measures the number of data points used in training to find a sparsity mask at sparsity x. Note that IMP does not match the dense performance here because, for both methods, we rewind to initialization rather than to an early point in training. Lottery tickets do not exist here past $\approx$ 40% sparsity.
  • Figure 3: Comparison of the stability of synthetic vs. IMP subnetworks at initialization on CIFAR-10. We show how the loss increases as the weights are interpolated between two trained models. We measure this for subnetworks of different sparsities. The left column shows subnetworks found via distilled data, and the middle column shows subnetworks found with real data. We aggregate all the information in each row for a better comparison. The dark lines in the 3D plots represent the pruning iterations we used for the combined plot; the dense model is iteration 0.
  • Figure 4: Comparison of the stability of synthetic vs. IMP subnetworks at initialization on ImageNet-10 and ResNet-10. An increased loss across the interpolation implies instability, i.e., trained networks landing in different minima.
  • Figure 5: Loss landscape visualization around the neighborhood defined by trained models on different seeds for ConvNet-3 and CIFAR-10.
  • ...and 2 more figures
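
The interpolation measurement described in the captions for Figures 3 and 4 can be sketched as follows. This is an illustrative toy, not the paper's evaluation code: the helper names `interpolate` and `loss_barrier` are hypothetical, the loss is a one-dimensional double well standing in for a network's loss, and the barrier is taken as the peak loss along the path minus the mean of the two endpoint losses, which is one common convention and an assumption here.

```python
def interpolate(w_a, w_b, alpha):
    """Point on the straight line between two weight vectors (alpha in [0, 1])."""
    return [(1 - alpha) * a + alpha * b for a, b in zip(w_a, w_b)]

def loss_barrier(loss_fn, w_a, w_b, steps=11):
    """Return (barrier, losses along the path) for a linear interpolation.

    Assumed convention: barrier = max loss on the path minus the mean of
    the two endpoint losses.
    """
    alphas = [i / (steps - 1) for i in range(steps)]
    path = [loss_fn(interpolate(w_a, w_b, a)) for a in alphas]
    baseline = 0.5 * (path[0] + path[-1])
    return max(path) - baseline, path

def double_well(w):
    """Toy loss with two separate minima, at w = -1 and w = +1."""
    return (w[0] ** 2 - 1) ** 2

# Two "training runs" that landed in different minima: large barrier.
barrier_unstable, _ = loss_barrier(double_well, [-1.0], [1.0])
# Two runs that landed in the same minimum: no barrier.
barrier_stable, _ = loss_barrier(double_well, [1.0], [1.0])
print(barrier_unstable, barrier_stable)
```

A barrier near zero corresponds to the flat interpolation curves that the captions associate with stability, while a large barrier corresponds to runs from the same initialization ending in different minima.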