S2HPruner: Soft-to-Hard Distillation Bridges the Discretization Gap in Pruning

Weihao Lin, Shengji Tang, Chong Yu, Peng Ye, Tao Chen

TL;DR

S2HPruner addresses the discretization gap in differentiable mask pruning by introducing a one-stage soft-to-hard distillation framework. It jointly optimizes a soft pruning mask and distills a corresponding hard network under supervision from the soft network, employing decoupled bidirectional knowledge distillation to avoid degrading the soft model. Empirically, it delivers superior pruning performance across CIFAR-100, Tiny ImageNet, and ImageNet for CNNs and Transformers without post-training, and provides analyses of gradient components and gap behavior to justify effectiveness. The work highlights the importance of bridging the soft-hard discrepancy in pruning, with practical impact on producing higher-quality pruned architectures under strict resource budgets, while noting limitations and avenues for future extensions.

Abstract

Recently, differentiable mask pruning methods have optimized a continuous relaxation of the architecture (the soft network) as a proxy for the pruned discrete network (the hard network) to search for superior sub-architectures. However, because the impact of discretization is not accounted for during this optimization, the hard network struggles to match the representational capacity of the soft network. This discrepancy, termed the discretization gap, severely degrades pruning performance. In this paper, we first investigate the discretization gap and then propose a novel structural differentiable mask pruning framework, named S2HPruner, that bridges the gap in a one-stage manner. During training, S2HPruner forwards both the soft network and its corresponding hard network, then distills the hard network under the supervision of the soft network. To optimize the mask while preventing performance degradation, we propose decoupled bidirectional knowledge distillation, which blocks weight updates from the hard network to the soft network while retaining the gradient with respect to the mask. Compared with existing pruning methods, S2HPruner achieves superior pruning performance without fine-tuning on comprehensive benchmarks, including CIFAR-100, Tiny ImageNet, and ImageNet, across a variety of network architectures. In addition, investigative and analytical experiments explain the effectiveness of S2HPruner. Code will be released soon.
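The decoupling rule described in the abstract can be illustrated on a toy linear unit. This is a minimal sketch, not the authors' implementation. It assumes a squared-error gap G = 1/2 (s - h)^2 (the paper's actual gap measure may differ), a hard mask obtained by thresholding the relaxed mask at 0.5, and a hard mask treated as a constant when differentiating. The function names `forward` and `decoupled_gap_grads` are hypothetical.

```python
# Toy sketch of decoupled gap gradients for a single linear unit.
# ASSUMPTIONS (not from the paper): gap G = 0.5 * (s - h)^2, hard mask
# m_hat = 1[w >= 0.5], and m_hat treated as a constant w.r.t. w.


def forward(theta, w, x):
    """Return soft output s, hard output h, and the binary mask m_hat."""
    m_hat = [1.0 if wj >= 0.5 else 0.0 for wj in w]
    s = sum(t * wj * xj for t, wj, xj in zip(theta, w, x))    # soft network
    h = sum(t * mj * xj for t, mj, xj in zip(theta, m_hat, x))  # hard network
    return s, h, m_hat


def decoupled_gap_grads(theta, w, x):
    """Gradients of G = 0.5 * (s - h)^2 under the decoupling rule.

    - Weights theta: s acts as a fixed teacher, so theta receives gap
      gradient only through the hard branch (hard learns from soft).
    - Mask w: gap gradient flows through the soft branch, so the mask is
      still optimized to shrink the gap, while no weight update from the
      hard network reaches the soft network.
    """
    s, h, m_hat = forward(theta, w, x)
    gap = 0.5 * (s - h) ** 2
    # dG/dh = (h - s); dh/dtheta_j = m_hat_j * x_j
    grad_theta = [(h - s) * mj * xj for mj, xj in zip(m_hat, x)]
    # dG/ds = (s - h); ds/dw_j = theta_j * x_j
    grad_w = [(s - h) * t * xj for t, xj in zip(theta, x)]
    return gap, grad_theta, grad_w


if __name__ == "__main__":
    gap, g_theta, g_w = decoupled_gap_grads(
        theta=[0.5, -1.0, 2.0], w=[0.9, 0.4, 0.6], x=[1.0, 1.0, 1.0]
    )
    print(gap, g_theta, g_w)
```

In this toy case the pruned channel (w = 0.4) gets no weight gradient from the gap, and a gradient-descent step on w lowers its mask value while raising the two kept channels, so the mask moves toward the hard mask and the gap shrinks from the mask side. In an autograd framework, the same split would typically be expressed with a stop-gradient on the weights in the soft branch (for example, `detach()` in PyTorch).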

Paper Structure

This paper contains 19 sections, 7 equations, 9 figures, 10 tables.

Figures (9)

  • Figure 1: Comparison of different typical pruning methods and illustration of the discretization gap. Darker colors indicate larger relative magnitudes of weights or masks. $\odot$ denotes the Hadamard product. For ease of demonstration, a single layer represents the entire network.
  • Figure 2: The proposed pruner's forward and backward flows, illustrated with an exemplary linear layer with parameters $\boldsymbol{\theta}$. $\boldsymbol{u}$ denotes the additional learnable parameters, normalized by a softmax. $\boldsymbol{w}$ denotes the relaxed mask, and $\hat{\boldsymbol{m}}$ is the estimated binary pruning mask. The input is denoted by $\boldsymbol{i}$. The outputs of the soft and hard networks are $\boldsymbol{s}$ and $\boldsymbol{h}$, respectively. $\mathcal{L}$, $\mathcal{G}$, and $\mathcal{R}$ are the performance loss, gap measure, and resource regularization, respectively.
  • Figure 3: Trajectories of FLOPs and accuracy. We report the accuracy and FLOPs of the hard network during training for (a) ResNet-50, (b) MobileNetV3, (c) WideResNet28-10, (d) ViT, and (e) Swin Transformer on CIFAR-100.
  • Figure 4: The architectures of (a) ResNet-50, (b) MobileNetV3, (c) WideResNet28-10, (d) ViT, and (e) Swin Transformer, pruned via our proposed method on CIFAR-100. The target FLOPs is set to 15%.
  • Figure 5: Detailed channel variation of ResNet-50 on CIFAR-100 during training. The target FLOPs is set to 15%. The horizontal axis represents training iterations; the vertical axis represents the number of output channels.
  • ...and 4 more figures
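The Figure 2 caption names the mask components (softmax-normalized parameters u, relaxed mask w, estimated binary mask m̂, and resource regularization R) but not their formulas. The following minimal sketch fills them in with one plausible, assumed construction, not necessarily the paper's: softmax(u)[k] is read as the probability of keeping exactly k + 1 channels, w_j is then the probability that channel j survives, m̂ thresholds w at 0.5, and R is a squared penalty on the expected kept-channel ratio against a target of 0.15. That target mirrors the 15% target FLOPs in Figures 4 and 5, but for a single layer the kept-channel ratio is only a stand-in for FLOPs. All function names are hypothetical.

```python
import math

# ASSUMED construction (hypothetical, for illustration only):
#   softmax(u)[k] = P(keep exactly k + 1 channels)
#   w[j]          = P(channel j is kept) = sum_{k >= j} softmax(u)[k]
#   m_hat[j]      = 1 if w[j] >= 0.5 else 0
#   R             = (mean(w) - target_ratio)^2


def softmax(u):
    """Numerically stable softmax over a list of logits."""
    m = max(u)
    exps = [math.exp(v - m) for v in u]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_mask(u):
    """Monotone non-increasing relaxed mask built from reverse cumulative sums."""
    p = softmax(u)
    w, tail = [0.0] * len(p), 0.0
    for j in range(len(p) - 1, -1, -1):
        tail += p[j]
        w[j] = tail  # probability that at least j + 1 channels are kept
    return w


def hard_mask(w, threshold=0.5):
    """Binarize the relaxed mask; threshold 0.5 is an assumption."""
    return [1 if wj >= threshold else 0 for wj in w]


def resource_penalty(w, target_ratio=0.15):
    """Squared deviation of the expected kept-channel ratio from the target."""
    expected_ratio = sum(w) / len(w)
    return (expected_ratio - target_ratio) ** 2


if __name__ == "__main__":
    w = relaxed_mask([0.0, 0.0, 0.0, 0.0])
    print(w, hard_mask(w), resource_penalty(w))
```

With uniform logits over four channels, w decays as [1.0, 0.75, 0.5, 0.25], the hard mask keeps the first three channels, and the penalty is about 0.2256, pushing the expected ratio of 0.625 down toward the target. Ordering the channels this way makes the hard mask always prune a contiguous tail, which is convenient for structural pruning.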