Pruning is Optimal for Learning Sparse Features in High-Dimensions
Nuri Mert Vural, Murat A. Erdogdu
TL;DR
This work explains why pruning can yield optimal feature learning in sparse, high-dimensional regimes. It proves that pruned neural networks trained with gradient descent attain the sample complexity predicted by correlational statistical query (CSQ) lower bounds for multi-index models whose relevant directions satisfy a soft ($\ell_q$) sparsity condition. It develops a pruning-based dimension-reduction procedure that, combined with carefully designed gradient steps and an even-odd Hermite decomposition, recovers the sparse directions with high probability. The authors derive sparsity-aware CSQ lower bounds and show that pruning attains them in both the single-index ($k^* \geq 1$) and multi-index ($k^* = 2$) cases, whereas basis-independent methods cannot. The results establish a theoretical separation between pruning-based and standard gradient methods, suggest practical benefits for feature learning in high dimensions, and have implications for understanding generalization and representation learning in sparse regimes.
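To make the roles of pruning and the even-odd Hermite decomposition concrete, here is a minimal NumPy sketch of one simple coordinate-selection rule. It is our own illustration, not the authors' procedure. It assumes Gaussian inputs $x \sim N(0, I_d)$, a single-index target $y = f(\langle v, x \rangle) + \varepsilon$ with an $s$-sparse unit vector $v$, and a hypothetical scoring rule (`hermite_scores`) that ranks coordinates by their first- and second-order empirical Hermite moments. By Gaussian symmetry, the odd part of $f$ is visible only through the degree-1 moment and the even part only through the degree-2 moment, so both are used. The hypothetical `prune_coordinates` then keeps the top $s$ coordinates, loosely mimicking pruning first-layer weights down to the support.

```python
import numpy as np


def even_odd_parts(f):
    """Split a scalar link f into f = f_even + f_odd."""
    def f_even(z):
        return 0.5 * (f(z) + f(-z))

    def f_odd(z):
        return 0.5 * (f(z) - f(-z))

    return f_even, f_odd


def hermite_scores(X, y):
    """Per-coordinate empirical Hermite moments of degree 1 and 2 (hypothetical rule).

    For x ~ N(0, I_d) and y = f(<v, x>) + noise, Stein's lemma gives
      E[y * x_i]          = E[f'(z)]  * v_i     (sees only the odd part of f)
      E[y * (x_i**2 - 1)] = E[f''(z)] * v_i**2  (sees only the even part of f)
    """
    n = X.shape[0]
    first = X.T @ y / n                # He_1(x_i) = x_i
    second = (X**2 - 1.0).T @ y / n    # He_2(x_i) = x_i^2 - 1
    return np.abs(first) + np.abs(second)


def prune_coordinates(X, y, s):
    """Keep the s input coordinates with the largest scores (illustrative pruning)."""
    scores = hermite_scores(X, y)
    return np.sort(np.argsort(scores)[-s:])


# Toy single-index target with an odd (He_1) and an even (He_2) component.
rng = np.random.default_rng(0)
d, s, n = 200, 4, 20_000
support = np.sort(rng.choice(d, size=s, replace=False))
v = np.zeros(d)
v[support] = 1.0 / np.sqrt(s)          # unit-norm, s-sparse direction


def link(z):
    return z + (z**2 - 1.0)            # He_1(z) + He_2(z)


X = rng.standard_normal((n, d))
y = link(X @ v) + 0.1 * rng.standard_normal(n)

kept = prune_coordinates(X, y, s)
f_even, f_odd = even_odd_parts(link)
```

With these toy parameters each support coordinate has an expected score of about $1$ (degree-1 moment $0.5$ plus degree-2 moment $0.5$), while off-support scores concentrate near $0$, so `kept` should coincide with `support`. Using only one of the two moments would miss a target whose link is purely even or purely odd.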
Abstract
While it is commonly observed in practice that pruning networks to a certain level of sparsity can improve the quality of the learned features, a theoretical explanation of this phenomenon remains elusive. In this work, we investigate this phenomenon by showing that a broad class of statistical models can be learned optimally in high dimensions using pruned neural networks trained with gradient descent. We consider learning both single-index and multi-index models of the form $y = \sigma^*(\boldsymbol{V}^{\top} \boldsymbol{x}) + \varepsilon$, where $\sigma^*$ is a degree-$p$ polynomial and $\boldsymbol{V} \in \mathbb{R}^{d \times r}$, with $r \ll d$, is the matrix containing the relevant model directions. We assume that $\boldsymbol{V}$ satisfies a certain $\ell_q$-sparsity condition for matrices and show that pruning neural networks in proportion to the sparsity level of $\boldsymbol{V}$ improves their sample complexity compared to unpruned networks. Furthermore, we establish Correlational Statistical Query (CSQ) lower bounds in this setting that take the sparsity level of $\boldsymbol{V}$ into account. We show that if the sparsity level of $\boldsymbol{V}$ exceeds a certain threshold, training pruned networks with a gradient descent algorithm achieves the sample complexity suggested by the CSQ lower bound. In the same scenario, however, our results imply that basis-independent methods, such as models trained via standard gradient descent initialized with rotationally invariant random weights, can provably achieve only suboptimal sample complexity.
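The data model above can be simulated in a few lines. The sketch below rests on illustrative assumptions throughout. It builds $\boldsymbol{V}$ with orthonormal columns supported on $s$ rows. It uses the row-wise quantity $\sum_i \lVert \boldsymbol{V}_{i,:} \rVert_2^q$ (the hypothetical `row_lq_measure`) as a stand-in for the $\ell_q$-sparsity condition; the paper's exact condition may differ. It also picks a hypothetical degree-3 link $\sigma^*(z) = z_1 z_2 + \mathrm{He}_3(z_1)$. Finally, it applies a random orthogonal change of basis $\boldsymbol{O}$ to show that this sparsity measure is not rotation invariant, which gives some intuition for why basis-independent methods cannot exploit it.

```python
import numpy as np


def row_lq_measure(V, q):
    """sum_i ||V[i, :]||_2^q: an assumed row-wise l_q proxy for matrix sparsity."""
    row_norms = np.linalg.norm(V, axis=1)
    return float(np.sum(row_norms**q))


def sparse_orthonormal_V(d, r, s, rng):
    """d x r matrix with orthonormal columns supported on s random rows."""
    support = np.sort(rng.choice(d, size=s, replace=False))
    Q, _ = np.linalg.qr(rng.standard_normal((s, r)))   # s x r, orthonormal columns
    V = np.zeros((d, r))
    V[support] = Q
    return V


def sample_multi_index(V, n, rng, noise_std=0.1):
    """Draw (X, y) with y = sigma*(V^T x) + eps and x ~ N(0, I_d)."""
    X = rng.standard_normal((n, V.shape[0]))
    Z = X @ V                                            # n x r projections
    # Hypothetical degree-3 link: sigma*(z) = z_1 z_2 + He_3(z_1)
    y = Z[:, 0] * Z[:, 1] + (Z[:, 0]**3 - 3.0 * Z[:, 0])
    return X, y + noise_std * rng.standard_normal(n)


rng = np.random.default_rng(1)
d, r, s, q = 500, 2, 10, 0.5
V = sparse_orthonormal_V(d, r, s, rng)

# A random orthogonal change of basis of the input space.
O, _ = np.linalg.qr(rng.standard_normal((d, d)))
V_rot = O @ V

X, y = sample_multi_index(V, n=1_000, rng=rng)
sparse_val = row_lq_measure(V, q)        # at most s^(1 - q/2) * r^(q/2)
rotated_val = row_lq_measure(V_rot, q)   # grows on the order of d^(1 - q/2)
```

For $q = 2$ the measure equals $\lVert \boldsymbol{V} \rVert_F^2 = r$ in any basis, but for $q < 2$ it stays small for the sparse $\boldsymbol{V}$ and becomes much larger after rotation. Since $\sigma^*(\boldsymbol{V}^{\top}\boldsymbol{x}) = \sigma^*((\boldsymbol{O}\boldsymbol{V})^{\top}\boldsymbol{O}\boldsymbol{x})$ and $\boldsymbol{O}\boldsymbol{x}$ is again standard Gaussian, a rotationally invariant method faces the same problem for $\boldsymbol{V}$ and $\boldsymbol{O}\boldsymbol{V}$, and so cannot benefit from the coordinate sparsity of $\boldsymbol{V}$.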
