PruneX: A Hierarchical Communication-Efficient System for Distributed CNN Training with Structured Pruning
Alireza Olama, Andreas Lundell, Izzat El Hajj, Johan Lilius, Jerker Björkqvist
TL;DR
PruneX tackles the inter-node bandwidth bottleneck in distributed CNN training by co-designing structured pruning with cluster topology through a Hierarchical Structured ADMM (H-SADMM) algorithm. It enforces node-level structured sparsity before inter-node synchronization, enabling physical buffer shrinkage and dense-kernel computation on compressed tensors, realized via a leader-follower architecture and a two-tier consensus. Empirical results on 64 GPUs of the Puhti supercomputer show a reduction of roughly 60% in inter-node communication volume and a 6.75x strong-scaling speedup, compared with 5.81x for the dense DDP baseline and 3.71x for Top-K gradient compression. The work demonstrates robust convergence and meaningful sparsity-accuracy trade-offs, with a clear path toward scaling to larger models and deeper hierarchies.
Abstract
Inter-node communication bandwidth increasingly constrains distributed training at scale on multi-node GPU clusters. While compact models are the ultimate deployment target, conventional pruning-aware distributed training systems typically fail to reduce communication overhead because unstructured sparsity cannot be efficiently exploited by highly optimized dense collective primitives. We present PruneX, a distributed data-parallel training system that co-designs pruning algorithms with cluster hierarchy to reduce inter-node bandwidth usage. PruneX introduces the Hierarchical Structured ADMM (H-SADMM) algorithm, which enforces node-level structured sparsity before inter-node synchronization, enabling dynamic buffer compaction that eliminates both zero-valued transmissions and indexing overhead. The system adopts a leader-follower execution model with separated intra-node and inter-node process groups, performing dense collectives on compacted tensors over bandwidth-limited links while confining full synchronization to high-bandwidth intra-node interconnects. Evaluation on ResNet architectures across 64 GPUs demonstrates that PruneX reduces inter-node communication volume by approximately 60% and achieves 6.75x strong scaling speedup, outperforming the dense baseline (5.81x) and Top-K gradient compression (3.71x) on the Puhti supercomputer at CSC - IT Center for Science (Finland).
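The two-tier synchronization described in the abstract can be illustrated with a small simulation. The following is only a minimal sketch under stated assumptions, not the PruneX implementation: it uses plain Python lists instead of GPU tensors, replaces the collectives with an element-wise mean, omits the ADMM updates entirely, and assumes that all node leaders share one filter mask so that compacted buffers line up without sending indices. All function names (mean_tensors, top_filters, compact, expand, hierarchical_sync) are hypothetical.

```python
import math

def mean_tensors(tensors):
    """Element-wise mean of equally shaped [filter][weight] lists.
    Stands in for an averaging all-reduce in this simulation."""
    n = len(tensors)
    return [[sum(t[f][i] for t in tensors) / n
             for i in range(len(tensors[0][f]))]
            for f in range(len(tensors[0]))]

def top_filters(layer, k):
    """Indices of the k filters with the largest L2 norm, ascending."""
    norms = [math.sqrt(sum(w * w for w in filt)) for filt in layer]
    ranked = sorted(range(len(layer)), key=lambda f: norms[f], reverse=True)
    return sorted(ranked[:k])

def compact(layer, keep):
    """Pack only the kept filters into a smaller dense buffer."""
    return [layer[f] for f in keep]

def expand(buf, keep, num_filters):
    """Scatter a compacted buffer back to full shape; pruned filters are zero."""
    full = [[0.0] * len(buf[0]) for _ in range(num_filters)]
    for row, f in zip(buf, keep):
        full[f] = list(row)
    return full

def hierarchical_sync(nodes, keep):
    """nodes[n][g] is the weight tensor held by GPU g on node n.
    Assumption: every leader uses the same `keep` mask."""
    # Tier 1: full dense averaging inside each node (fast interconnect).
    node_means = [mean_tensors(gpus) for gpus in nodes]
    # Tier 2: leaders exchange only the compacted, still dense, buffers.
    buffers = [compact(m, keep) for m in node_means]
    values_sent_per_leader = sum(len(row) for row in buffers[0])
    synced = mean_tensors(buffers)
    return expand(synced, keep, len(node_means[0])), values_sent_per_leader

# Toy setup: 2 nodes x 2 GPUs, one layer with 4 filters of 2 weights each.
nodes = [
    [[[1.0, 1.0], [0.1, 0.0], [2.0, 2.0], [0.0, 0.1]],
     [[3.0, 1.0], [0.1, 0.2], [2.0, 0.0], [0.2, 0.1]]],
    [[[1.0, 3.0], [0.0, 0.0], [4.0, 2.0], [0.1, 0.1]],
     [[3.0, 3.0], [0.2, 0.0], [0.0, 0.0], [0.1, 0.3]]],
]
# Assumption: the mask is chosen from node 0's averaged weights and shared.
keep = top_filters(mean_tensors(nodes[0]), 2)
synced, sent = hierarchical_sync(nodes, keep)
print(keep, sent, synced)
```

In this toy case each leader sends 4 values instead of the 8 a dense exchange would need. Because the kept filters are packed contiguously, the inter-node step remains an ordinary dense reduction over a smaller buffer, which is the property the abstract attributes to structured (as opposed to unstructured) sparsity.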
