Stochastic Subnetwork Annealing: A Regularization Technique for Fine Tuning Pruned Subnetworks
Tim Whitaker, Darrell Whitley
TL;DR
The paper tackles the convergence challenges of pruning large neural networks by introducing Stochastic Subnetwork Annealing (SSA), which replaces fixed subnetworks with probabilistic masks and anneals the inclusion probabilities over time. By sampling masks from a probability matrix $P$ and gradually driving the subnetwork toward a deterministic topology, SSA enables smoother optimization and better generalization, especially at very high sparsity. Extensive ablations across CNNs and Vision Transformers show SSA outperforming one-shot and iterative pruning, and its integration into low-cost ensembles (Prune and Tune) yields improved generalization with reduced training compute. This approach offers a flexible, hardware-friendly regularization technique for fine-tuning pruned subnetworks and adapting to diverse architectures and datasets.
Abstract
Pruning methods have recently grown in popularity as an effective way to reduce the size and computational complexity of deep neural networks. Large numbers of parameters can be removed from trained models with little discernible loss in accuracy after a small number of continued training epochs. However, pruning too many parameters at once often causes an initial steep drop in accuracy, which can undermine convergence quality. Iterative pruning approaches mitigate this by gradually removing a small number of parameters over multiple epochs, but they can still lead to subnetworks that overfit local regions of the loss landscape. We introduce a novel and effective approach to tuning subnetworks through a regularization technique we call Stochastic Subnetwork Annealing. Rather than removing parameters in a discrete manner, we represent subnetworks with stochastic masks in which each parameter has a probabilistic chance of being included or excluded on any given forward pass. We anneal these probabilities over time so that the subnetwork structure slowly evolves as mask values become more deterministic, allowing for smoother and more robust optimization of subnetworks at high levels of sparsity.
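The mechanism described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names, the linear annealing schedule, and the starting probability `p_start` are all assumptions chosen for illustration, and the abstract does not specify which schedule is used. The sketch assumes a target binary mask (1 = keep, 0 = prune) has already been chosen by some pruning criterion.

```python
import numpy as np

def inclusion_probabilities(target_mask, step, total_steps, p_start=0.5):
    """Build the probability matrix P for one annealing step.

    Weights kept by the target subnetwork (target_mask == 1) always have
    inclusion probability 1. Weights slated for pruning start at p_start
    and decay linearly to 0 (the linear schedule is an assumption).
    """
    frac = min(max(step / total_steps, 0.0), 1.0)  # clamp progress to [0, 1]
    p_pruned = p_start * (1.0 - frac)
    return np.where(target_mask == 1, 1.0, p_pruned)

def sample_mask(P, rng):
    """Draw a binary mask with independent Bernoulli(P[i, j]) entries."""
    return (rng.random(P.shape) < P).astype(P.dtype)

def masked_forward(x, W, mask):
    """Linear layer forward pass using only the sampled subnetwork."""
    return x @ (W * mask)

# Illustrative usage: a fresh mask is sampled on every forward pass, and the
# masks become deterministic once the annealing schedule reaches its end.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))
target_mask = (np.abs(W) > 0.5).astype(int)  # hypothetical magnitude criterion
x = rng.normal(size=(2, 4))
for step in range(0, 11, 5):
    P = inclusion_probabilities(target_mask, step, total_steps=10)
    y = masked_forward(x, W, sample_mask(P, rng))
```

At the final step every probability is exactly 0 or 1, so the sampled mask equals the target mask and the subnetwork topology is fixed.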
