Adaptive Computation Modules: Granular Conditional Computation For Efficient Inference

Bartosz Wójcik, Alessio Devoto, Karol Pustelnik, Pasquale Minervini, Simone Scardapane

TL;DR

The paper introduces Adaptive Computation Modules (ACMs), a granular conditional computation approach that allocates per-token compute by stacking a sequence of learners and gating how many are executed for each token. It outlines a distillation-based initialization strategy to replace pre-trained blocks with ACMs and a three-phase training pipeline (distillation, gating pre-training, and end-to-end fine-tuning) augmented by auxiliary losses to encourage diverse and semantically meaningful compute. Empirical results in vision and speech show ACMs significantly reduce inference costs while preserving downstream accuracy, achieving favorable cost-accuracy trade-offs across a wide range of budgets and demonstrating hardware-efficient implementations. The work offers a practical, plug-and-play path to more energy-efficient transformers and points to future directions such as pruning and quantization to push further gains.

Abstract

While transformer models have been highly successful, they are computationally inefficient. We observe that for each layer, the full width of the layer may be needed only for a small subset of tokens inside a batch and that the "effective" width needed to process a token can vary from layer to layer. Motivated by this observation, we introduce the Adaptive Computation Module (ACM), a generic module that dynamically adapts its computational load to match the estimated difficulty of the input on a per-token basis. An ACM consists of a sequence of learners that progressively refine the output of their preceding counterparts. An additional gating mechanism determines the optimal number of learners to execute for each token. We also propose a distillation technique to replace any pre-trained model with an "ACMized" variant. Our evaluation of transformer models in computer vision and speech recognition demonstrates that substituting layers with ACMs significantly reduces inference costs without degrading the downstream accuracy for a wide interval of user-defined budgets.
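The mechanism described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names `make_acm` and `acm_forward`, the two-layer ReLU learners, the linear gate, the hard argmax decision, and the choice that every token runs at least one learner are all assumptions made for illustration. The only parts taken from the text are that the gate picks a per-token number of learners k and that the module output is the sum of the first k learners.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_acm(d_model, d_hidden, n_learners):
    """Random weights for a hypothetical ACM: one linear gate and n small learners."""
    return {
        "gate": rng.normal(size=(d_model, n_learners)) * 0.1,
        "w_in": rng.normal(size=(n_learners, d_model, d_hidden)) * 0.1,
        "w_out": rng.normal(size=(n_learners, d_hidden, d_model)) * 0.1,
    }


def acm_forward(params, x):
    """x has shape (tokens, d_model). Returns (output, k), k[i] = learners run for token i."""
    n_learners = params["w_in"].shape[0]
    # Assumption: hard argmax gate, with k restricted to 1..n_learners.
    k = np.argmax(x @ params["gate"], axis=-1) + 1
    out = np.zeros_like(x)
    for n in range(n_learners):
        active = k > n  # tokens whose budget includes learner n
        if not active.any():
            break  # no token needs learner n, so none needs a later one either
        hidden = np.maximum(x[active] @ params["w_in"][n], 0.0)
        out[active] += hidden @ params["w_out"][n]  # output is the sum of executed learners
    return out, k


params = make_acm(d_model=8, d_hidden=16, n_learners=4)
x = rng.normal(size=(5, 8))
out, k = acm_forward(params, x)
# Fraction of learner evaluations actually performed, relative to running all of them.
compute_fraction = k.sum() / (k.size * 4)
print(k, compute_fraction)
```

Each learner only processes the tokens whose k exceeds its index, so tokens with a small k skip the later learners entirely. This is where the per-token savings would come from in this simplified setting.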

Paper Structure (21 sections, 12 equations, 9 figures)

Figures (9)

  • Figure 1: ACMs adapt their computational load for each input on a per-token basis by selecting the number of learners to execute via a trainable gate. In the example on the top, a background token (green) is allocated fewer learners than a content-rich token (orange). This results in a spatially varying computational load, as shown on the bottom.
  • Figure 2: Architecture of an ACM block: the output is the sum of $k$ learners, where $k$ is determined on a per-token basis by a small gating network $g$. The learners are executed in parallel. In the example, only the first two learners are executed, and the computation of the third (greyed out) is skipped.
  • Figure 3: Performance-efficiency trade-offs of different conditional computation methods as measured on the ImageNet-1k dataset. ACM-based ViT-B achieves the Pareto frontier for a wide range of computational budgets.
  • Figure 4: Performance-efficiency trade-offs of different conditional computation methods as measured on the CommonVoice-es dataset. The model's performance is reported in terms of Word Error Rate (WER). ACMs achieve lower WER for every evaluated computational budget.
  • Figure 5: Color-coded errors of a $4$-learner ACM plotted after performing module-wise representation distillation for modules from the eighth block of a pre-trained ViT-B model. Tokens are sorted along the Y-axis of the plot by their average error. For most input tokens, the same transformation can be performed by a considerably smaller module consisting of only two or three learners, thus justifying the use of ACMs.
  • ...and 4 more figures