Automatic Channel Pruning for Multi-Head Attention
Eunho Lee, Youngbae Hwang
TL;DR
This work tackles the challenge of pruning vision transformers without collapsing the expressive power of multi-head attention. It introduces APMA, a multi-head-aware automatic pruning framework that combines similarity-based weights, head-wise pruning synchronization, a reweight module to offset information loss, and a data-driven initialization for linear attention, enabling pruning of both original and linear attention mechanisms. Applied to FLattenTransformer, APMA yields state-of-the-art efficiency-accuracy trade-offs at several MAC budgets, with notable throughput gains, and ablations show the contribution of each component. The approach advances practical deployment of efficient transformers on large-scale vision benchmarks such as ImageNet-1K and sets the stage for extension to detection and segmentation tasks.
Abstract
Despite the strong performance of Transformers, their quadratic computational complexity makes them challenging to apply to vision tasks. Automatic pruning is an effective way to reduce computational complexity without relying on heuristics. However, applying it directly to multi-head attention is not straightforward because it causes channel misalignment across heads. In this paper, we propose an automatic channel pruning method that takes the multi-head attention mechanism into account. First, we incorporate channel similarity-based weights into the pruning indicator to preserve the more informative channels in each head. Then, we adjust the pruning indicator to enforce the removal of channels in equal proportions across all heads, which prevents channel misalignment. We also add a reweight module to compensate for the information loss caused by channel removal, and an effective initialization step for the pruning indicator based on the difference in attention between the original structure and each channel. Our method can be applied not only to original attention but also to linear attention, which is more efficient because its complexity is linear in the number of tokens. On ImageNet-1K, applying our pruning method to FLattenTransformer, which includes both attention mechanisms, achieves higher accuracy at several MAC budgets than previous state-of-the-art efficient models and pruning methods. Code will be available soon.
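To make the channel-misalignment point concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes a hypothetical per-channel importance score array of shape (num_heads, head_dim) and a hypothetical keep_ratio parameter, and it contrasts a global top-k selection (which can leave heads with different channel counts) with a head-wise selection that removes the same proportion of channels from every head. The function names and the toy scores are illustrative assumptions only; the similarity-based weights, reweight module, and initialization step described above are not modeled here.

```python
import numpy as np


def global_topk_mask(scores, keep_ratio):
    """Keep the highest-scoring channels overall, ignoring head boundaries.

    Illustrative baseline (assumption): heads can end up with different
    numbers of surviving channels, so they no longer share one head_dim.
    """
    num_heads, head_dim = scores.shape
    k = int(round(keep_ratio * num_heads * head_dim))
    flat = scores.ravel()
    keep = np.argsort(flat)[::-1][:k]  # indices of the k largest scores
    mask = np.zeros(flat.shape, dtype=bool)
    mask[keep] = True
    return mask.reshape(num_heads, head_dim)


def headwise_topk_mask(scores, keep_ratio):
    """Keep the same number of channels in every head (equal proportions)."""
    num_heads, head_dim = scores.shape
    k = max(1, int(round(keep_ratio * head_dim)))
    # Per-head ranking: column indices of the k largest scores in each row.
    order = np.argsort(-scores, axis=1)[:, :k]
    mask = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(mask, order, True, axis=1)
    return mask


# Toy scores (hypothetical): head 0 dominates head 1 on every channel.
scores = np.array([[0.9, 0.8, 0.7, 0.6],
                   [0.5, 0.1, 0.2, 0.3]])

global_mask = global_topk_mask(scores, keep_ratio=0.5)
head_mask = headwise_topk_mask(scores, keep_ratio=0.5)

print(global_mask.sum(axis=1))  # [4 0]: head 1 loses every channel
print(head_mask.sum(axis=1))    # [2 2]: every head keeps 2 channels
```

With the head-wise mask, the surviving query, key, and value channels can still be reshaped into num_heads heads of equal width, which is the property the equal-proportion constraint is meant to preserve.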
