Automatic Channel Pruning for Multi-Head Attention

Eunho Lee, Youngbae Hwang

TL;DR

This work tackles the challenge of pruning vision transformers without collapsing the expressive power of multi-head attention. It introduces APMA, a multi-head-aware automatic pruning framework that combines similarity-based weights, head-wise pruning synchronization, a reweight module to offset information loss, and a data-driven initialization of the pruning indicator for linear attention, so that both original and linear attention can be pruned. Applied to FLattenTransformer, APMA achieves state-of-the-art efficiency-accuracy trade-offs at several MAC budgets, with notable throughput gains, and its ablations show the contribution of each component. The approach advances practical deployment of efficient transformers on large-scale vision benchmarks such as ImageNet-1K and sets the stage for extension to detection and segmentation tasks.

Abstract

Despite the strong performance of Transformers, their quadratic computational complexity makes them challenging to apply to vision tasks. Automatic pruning is an effective way to reduce computational complexity without relying on heuristics. However, applying it directly to multi-head attention is not straightforward because of channel misalignment. In this paper, we propose an automatic channel pruning method that takes the multi-head attention mechanism into account. First, we incorporate channel similarity-based weights into the pruning indicator to preserve the more informative channels in each head. Then, we adjust the pruning indicator to enforce removal of channels in equal proportions across all heads, preventing channel misalignment. We also add a reweight module to compensate for the information loss caused by channel removal, and an effective initialization step for the pruning indicator based on the difference in attention between the original structure and each channel. Our method applies not only to the original attention but also to linear attention, which is more efficient because its complexity is linear in the number of tokens. On ImageNet-1K, applying our pruning method to the FLattenTransformer, which includes both attention mechanisms, yields higher accuracy at several MAC budgets than previous state-of-the-art efficient models and pruning methods. Code will be available soon.
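
The linear-complexity claim for linear attention comes from reordering the matrix products so that the N x N token-similarity matrix is never formed. The following is a minimal NumPy sketch of that idea, not the authors' implementation: the feature map `phi` (ReLU plus a small constant) and the function names are assumptions chosen for illustration, and the model in the paper may use a different mapping.

```python
import numpy as np

def phi(x):
    # Hypothetical positive feature map (an assumption, not the paper's choice).
    return np.maximum(x, 0.0) + 1e-6

def attention_quadratic(q, k, v):
    # Builds the full N x N similarity matrix: O(N^2 * d) time.
    scores = phi(q) @ phi(k).T                      # (N, N)
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

def attention_linear(q, k, v):
    # Same result by associativity, without the N x N matrix: O(N * d^2) time.
    kv = phi(k).T @ v                               # (d, d_v)
    k_sum = phi(k).sum(axis=0)                      # (d,)
    return (phi(q) @ kv) / (phi(q) @ k_sum)[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 4)) for _ in range(3))
out_quadratic = attention_quadratic(q, k, v)
out_linear = attention_linear(q, k, v)
```

Both functions compute the same normalized output; only the order of multiplication differs, which is why the cost of the second grows linearly with the number of tokens N.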

Paper Structure (17 sections, 9 equations, 5 figures, 3 tables)

Figures (5)

  • Figure 1: Problems that arise when the multi-head structure is ignored. (a) Attention is computed separately in each head, so the heads form different representation subspaces. (b) If the multi-head structure is ignored, reconfiguration causes channel misalignment: representation subspaces get mixed, or whole subspaces are removed, which significantly reduces expressive capacity. (c) Our method resolves these issues with an automatic pruning method that accounts for the multi-head structure.
  • Figure 2: Multi-head automatic pruning process. (a) Computing the pruning indicator: we incorporate similarity-based weights into the pruning indicators so that they account for the salient channels in each head. (b) Pruning indicator adjustment: we share rank-wise pruning indicators across heads, which ensures equal channel removal in all heads and prevents channel misalignment.
  • Figure 3: Top-3 singular value norms for each attention layer of the pruned model. With similarity-based weights, the singular value norms are larger, indicating that the proposed method effectively retains salient channels.
  • Figure 4: The effect of the reweight module. The red block represents the query token, while the blue blocks depict the relationship between this query token and other tokens. The figure shows that the reweight module can compensate for information loss caused by pruning, directing attention to more relevant tokens.
  • Figure 5: Pruned model structure after reconfiguration. (a) Without multi-head pruning, the pruning ratio differs from head to head, leading to channel misalignment. (b) Without pruning indicator initialization for linear attention, the attention mechanism does not function properly. (c) Our proposed method keeps pruning ratios consistent across heads, which resolves channel misalignment, and prunes each module in appropriate proportions. Zoom in for details.
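
The channel misalignment described in the captions for Figures 1, 2 and 5 can be illustrated with a small example. This is a simplified sketch under stated assumptions, not the paper's algorithm: the importance scores are made up, the function names are hypothetical, and a fixed per-head keep count stands in for the learned, rank-wise shared pruning indicators.

```python
import numpy as np

def global_topk_mask(scores, keep_total):
    # Baseline that ignores heads: keep the keep_total best channels overall.
    flat = scores.ravel()
    keep = np.argsort(flat)[::-1][:keep_total]
    mask = np.zeros(flat.shape, dtype=bool)
    mask[keep] = True
    return mask.reshape(scores.shape)

def headwise_mask(scores, keep_per_head):
    # Head-aware variant: rank channels inside each head and keep the
    # same number of top-ranked channels in every head.
    order = np.argsort(-scores, axis=1)             # per-head ranking, best first
    mask = np.zeros(scores.shape, dtype=bool)
    rows = np.arange(scores.shape[0])[:, None]
    mask[rows, order[:, :keep_per_head]] = True
    return mask

# Hypothetical importance scores for 3 heads x 4 channels per head.
scores = np.array([[0.9, 0.8, 0.7, 0.6],
                   [0.5, 0.4, 0.1, 0.1],
                   [0.3, 0.2, 0.1, 0.0]])

global_counts = global_topk_mask(scores, 6).sum(axis=1)  # channels kept per head
head_counts = headwise_mask(scores, 2).sum(axis=1)
```

With these scores, the head-agnostic baseline keeps 4, 2 and 0 channels in the three heads, so the third head's subspace disappears entirely, while the head-aware variant keeps 2 channels in every head for the same total budget.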