Winning the Lottery by Preserving Network Training Dynamics with Concrete Ticket Search

Tanay Arora, Christof Teuscher

TL;DR

This work identifies a fundamental shortcoming of pruning-at-initialization methods that rely on first-order saliency, arguing that inter-weight dependencies and training dynamics are essential for effective sparsity. It introduces Concrete Ticket Search (CTS), a holistic, differentiable framework that uses a Concrete relaxation and the GradBalance gradient balancing scheme to discover lottery-ticket subnetworks near initialization, with CTS-KL as a dynamics-preserving objective. Empirical results on CIFAR-10 and ImageNet show that CTS yields subnetworks that pass sanity checks and match or surpass Lottery Ticket Rewinding (LTR) at high sparsity while requiring far less compute. Overall, CTS offers a scalable, near-initialization route to trainable subnetworks with strong performance and practical speedups, especially in the highly sparse regime.

Abstract

The Lottery Ticket Hypothesis asserts the existence of highly sparse, trainable subnetworks ('winning tickets') within dense, randomly initialized neural networks. However, state-of-the-art methods for drawing these tickets, like Lottery Ticket Rewinding (LTR), are computationally prohibitive, while more efficient saliency-based Pruning-at-Initialization (PaI) techniques suffer from a significant accuracy-sparsity trade-off and fail basic sanity checks. In this work, we argue that PaI's reliance on first-order saliency metrics, which ignore inter-weight dependencies, contributes substantially to this performance gap, especially in the sparse regime. To address this, we introduce Concrete Ticket Search (CTS), an algorithm that frames subnetwork discovery as a holistic combinatorial optimization problem. By leveraging a Concrete relaxation of the discrete search space and a novel gradient balancing scheme (GradBalance) to control sparsity, CTS efficiently identifies high-performing subnetworks near initialization without requiring sensitive hyperparameter tuning. Motivated by recent works on lottery ticket training dynamics, we further propose a knowledge distillation-inspired family of pruning objectives, finding that minimizing the reverse Kullback-Leibler divergence between sparse and dense network outputs (CTS-KL) is particularly effective. Experiments on various image classification tasks show that CTS produces subnetworks that robustly pass sanity checks and achieve accuracy comparable to or exceeding LTR, while requiring only a small fraction of the computation. For example, with ResNet-20 on CIFAR-10, it reaches 99.3% sparsity with 74.0% accuracy in 7.9 minutes, while LTR attains the same sparsity with 68.3% accuracy in 95.2 minutes. CTS's subnetworks outperform saliency-based methods across all sparsities, but its advantage over LTR is most pronounced in the highly sparse regime.
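
To make the two ingredients named in the abstract more concrete, the following is a minimal sketch in plain Python, not the authors' implementation. It assumes the binary Concrete (relaxed Bernoulli) form of the relaxation, with a hypothetical per-weight `logit` and `temperature`, and it assumes that "reverse KL" follows the usual distillation convention, KL(sparse || dense), with the dense network as the teacher. All function names are illustrative.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def relaxed_bernoulli_mask(logit, temperature, rng):
    """Draw one soft mask value in [0, 1] from a binary Concrete distribution.

    `logit` is the (hypothetical) learnable keep-score of a weight; lower
    temperatures push samples toward the discrete values 0 and 1.
    """
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    logistic_noise = math.log(u) - math.log(1.0 - u)
    return sigmoid((logit + logistic_noise) / temperature)


def log_softmax(logits):
    """Log-probabilities computed with the log-sum-exp trick."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def reverse_kl(sparse_logits, dense_logits):
    """KL(q_sparse || p_dense) over class probabilities (assumed direction)."""
    log_q = log_softmax(sparse_logits)
    log_p = log_softmax(dense_logits)
    return sum(math.exp(lq) * (lq - lp) for lq, lp in zip(log_q, log_p))


if __name__ == "__main__":
    rng = random.Random(0)
    soft_mask = [relaxed_bernoulli_mask(0.5, 0.3, rng) for _ in range(5)]
    print(soft_mask)
    print(reverse_kl([2.0, 0.5, -1.0], [1.5, 0.7, -0.8]))
```

Because the soft mask is a differentiable function of `logit`, a gradient-based optimizer can adjust the keep-scores against an objective such as `reverse_kl`. How CTS balances that objective against the sparsity target is the role of GradBalance, which this sketch does not model.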

Paper Structure

This paper contains 26 sections, 17 equations, 10 figures, 2 tables, 3 algorithms.

Figures (10)

  • Figure 1: Performance of saliency-based pruning on its corresponding objective functions, compared with LTR and least-magnitude pruning. As described in Sections \ref{s:issuewithsaliency} and \ref{ss:choosingtheobjectivefunction}, SNIP and GraSP can be seen as optimizing the objectives $\mathcal{R}_{\Delta \mathcal{L}}$ and $\mathcal{R}_{\lVert \nabla \rVert_2}$, respectively (cf. \ref{saliencydefinition}). Plotted values are calculated from VGG-16 on CIFAR-10 at initialization. Tickets drawn through saliency pruning rarely outperform baselines, even on the objectives they are designed to optimize.
  • Figure 2: Test accuracy with respect to sparsity of subnetworks of ResNet-20 (left) and VGG-16 (right) on CIFAR-10 produced by the proposed method, LTR, SNIP, GraSP, and SynFlow. Shaded intervals are confidence intervals, taken over three runs. We plot two objectives of CTS, $\mathcal{R}_{\text{KL}}$ and $\mathcal{R}_{\mathcal{L}}$. Subnetworks for CTS are generated according to Algorithm \ref{subnetworkgenerationalgorithm}, with the GradBalance gradient step.
  • Figure 3: Test accuracy with respect to sparsity of subnetworks of ResNet-20 (top) and VGG-16 (bottom) on CIFAR-10 produced by the proposed method over the six objective functions outlined in Section \ref{ss:choosingtheobjectivefunction}. Shaded intervals are confidence intervals, taken over three runs. Subnetworks for CTS are generated according to Algorithm \ref{subnetworkgenerationalgorithm}, with the GradBalance gradient step. Tickets drawn under the $\mathcal{R}_{-\lVert \nabla_{\theta} \mathcal{L} \rVert_2}$, $\mathcal{R}_{\text{feature}}$, and $\mathcal{R}_{\text{grad}}$ objectives cannot match the performance of those drawn under $\mathcal{R}_{\text{KL}}$, $\mathcal{R}_{\Delta \mathcal{L}}$, and $\mathcal{R}_{\mathcal{L}}$.
  • Figure 4: Test accuracy with respect to sparsity of subnetworks of ResNet-20 on CIFAR-10, under the sanity-check methods suggested in \cite{pruningatinitialization}. These include weight reinitialization, score inversion, and layerwise shuffling. We test these for CTS$_{\text{KL}}$ (top) and $\overline{\text{CTS}_{\mathcal{L}}}$ (bottom) (cf. Section \ref{sanitycheckingsection}). Shaded intervals are confidence intervals, taken over three runs. Tickets drawn through CTS consistently outperform those with sanity-checking ablations, especially in the sparse regime.
  • Figure 5: Test accuracy with respect to sparsity of subnetworks of VGG-16 on CIFAR-10, under the sanity-check methods suggested in \cite{pruningatinitialization}. These include weight reinitialization, score inversion, and layerwise shuffling. We test these for CTS$_{\text{KL}}$ (top) and $\overline{\text{CTS}_{\mathcal{L}}}$ (bottom) (cf. Section \ref{sanitycheckingsection}). Shaded intervals are confidence intervals, taken over three runs. Tickets drawn through CTS consistently outperform those with sanity-checking ablations, especially in the sparse regime.
  • ...and 5 more figures
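
The three sanity checks named in the captions of Figures 4 and 5 can be illustrated with a small sketch. It reflects a common reading of these ablations rather than the paper's exact procedure, and the helper names, the flat list-of-lists mask layout, and the Gaussian initializer with `std=0.1` are all assumptions made for illustration.

```python
import random


def shuffle_layerwise(masks, rng):
    """Randomly permute each layer's mask, keeping per-layer sparsity fixed."""
    shuffled = []
    for layer_mask in masks:
        permuted = list(layer_mask)
        rng.shuffle(permuted)
        shuffled.append(permuted)
    return shuffled


def invert_scores(scores, keep):
    """Keep the `keep` lowest-scoring weights instead of the highest-scoring."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    kept = set(order[:keep])
    return [1 if i in kept else 0 for i in range(len(scores))]


def reinitialize(mask, rng, std=0.1):
    """Resample surviving weights from a (hypothetical) Gaussian initializer."""
    return [rng.gauss(0.0, std) * m for m in mask]


if __name__ == "__main__":
    rng = random.Random(0)
    masks = [[1, 0, 0, 0], [1, 1, 0, 0, 0, 0]]
    print(shuffle_layerwise(masks, rng))
    print(invert_scores([0.9, 0.1, 0.5, 0.3], keep=2))
    print(reinitialize(masks[0], rng))
```

If a pruning method's accuracy is unchanged under these ablations, the specific weights it selected carry little information beyond the per-layer sparsity ratios. A ticket that degrades under them, as the captions report for CTS, depends on which weights were kept and on their initial values.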