Attention Is All You Need For Mixture-of-Depths Routing

Advait Gadhikar, Souptik Kumar Majumdar, Niclas Popp, Piyapat Saranrittichai, Martin Rapp, Lukas Schott

TL;DR

This work targets the computational bottleneck of large transformer models by enhancing Mixture-of-Depths (MoD) with a parameter-free attention-based router, A-MoD, which uses the previous layer's attention maps to decide which tokens to process. By eliminating extra trainable routing parameters, A-MoD enables easier adaptation from pretrained checkpoints and improves stability, achieving up to 2% higher ImageNet accuracy and up to 2x faster transfer learning compared to standard MoD and isoFLOP baselines. Across multiple Vision Transformer variants and datasets, A-MoD consistently steers computation toward the most informative tokens, with routing scores exhibiting higher correlation to leave-one-out token importance than conventional routers. The results demonstrate that attention-based routing can yield better Pareto efficiency (accuracy vs FLOPs) and faster convergence in transfer scenarios, making A-MoD a practical enhancement for scalable, efficient vision models.

Abstract

Advancements in deep learning are driven by training models with increasingly larger numbers of parameters, which in turn heightens the computational demands. To address this issue, Mixture-of-Depths (MoD) models have been proposed to dynamically assign computations only to the most relevant parts of the inputs, thereby enabling the deployment of large-parameter models with high efficiency during inference and training. These MoD models utilize a routing mechanism to determine which tokens should be processed by a layer and which should be skipped. However, conventional MoD models employ additional network layers specifically for routing, which are difficult to train and add complexity and deployment overhead to the model. In this paper, we introduce a novel attention-based routing mechanism, A-MoD, that leverages the existing attention map of the preceding layer for routing decisions within the current layer. Compared to standard routing, A-MoD allows for more efficient training as it introduces no additional trainable parameters and can be easily adapted from pretrained transformer models. It can also increase the performance of the MoD model: for instance, we observe up to 2% higher accuracy on ImageNet compared to standard routing and isoFLOP ViT baselines. In addition, A-MoD improves MoD training convergence, leading to up to 2x faster transfer learning.
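
The routing idea described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes that a token's routing score is the attention it receives in the previous layer, averaged over heads and query positions, and that skipped tokens pass through unchanged. The function names, the `capacity` argument, and the `block` callable are hypothetical, and any scaling of the block output by the routing score is left out.

```python
import numpy as np

def attention_routing_scores(attn_prev):
    """Score each token from the previous layer's attention map.

    attn_prev has shape (heads, tokens, tokens); each row is assumed to be
    a softmax distribution over key tokens. Averaging over heads and query
    rows is an assumption made for this sketch.
    """
    return attn_prev.mean(axis=(0, 1))

def route_tokens(x, attn_prev, block, capacity=0.5):
    """Apply `block` only to the top-scoring fraction of tokens.

    x has shape (tokens, dim). Tokens that are not selected skip the block
    and keep their input representation. No trainable router is involved.
    """
    n_tokens = x.shape[0]
    k = max(1, int(capacity * n_tokens))
    scores = attention_routing_scores(attn_prev)
    # Indices of the k highest-scoring tokens, restored to original order.
    keep = np.sort(np.argsort(-scores, kind="stable")[:k])
    out = x.copy()
    out[keep] = x[keep] + block(x[keep])  # residual update for routed tokens
    return out, keep
```

With `capacity=0.5`, half of the tokens are processed by the block, which corresponds to the 50% capacity setting referred to in the figure captions.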

Paper Structure

This paper contains 36 sections, 5 equations, 19 figures, 9 tables, 1 algorithm.

Figures (19)

  • Figure 1: Accuracy vs FLOPs Pareto curve for A-MoD in comparison with MoD and isoFLOP models on ImageNet-1k.
  • Figure 2: MoD model (a) with standard routing (b) vs. our A-MoD attention routing (c).
  • Figure 3: A-MoD achieves better performance and faster convergence on ImageNet-1k. Finetuning with A-MoD: Results comparing A-MoD with standard routing and isoFLOP baselines with $50\%$ capacity on ImageNet.
  • Figure 4: A-MoD converges faster across different datasets. Transfer learning with A-MoD: A-MoD at $50\%$ capacity trained on the Flower102 dataset. Dotted lines denote the epochs needed to reach within $2\%$ of peak accuracy.
  • Figure 5: A-MoD exhibits more meaningful routing compared to MoD. Routing visualization: Example of DeiT-Small with $50\%$ capacity on ImageNet. Each example shows the tokens chosen by standard MoD (top) and A-MoD (bottom) for every MoD layer; white patches denote skipped tokens. Each column represents one MoD layer, with depth increasing from left to right.
  • ...and 14 more figures
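
The convergence marker mentioned in the Figure 4 caption (epochs needed to reach within 2% of peak accuracy) can be computed from an accuracy curve as in the sketch below. It assumes the 2% is measured in absolute accuracy points and that epochs are counted from 1; the function name and the `tolerance` parameter are hypothetical, and the caption does not specify either convention.

```python
def epochs_to_near_peak(acc_per_epoch, tolerance=2.0):
    """Return the first (1-indexed) epoch within `tolerance` points of the peak.

    acc_per_epoch is a non-empty sequence of accuracies in percent,
    one value per epoch.
    """
    peak = max(acc_per_epoch)
    for epoch, acc in enumerate(acc_per_epoch, start=1):
        if acc >= peak - tolerance:
            return epoch
```

Dividing this value for a baseline run by the value for an A-MoD run gives a convergence speedup ratio of the kind the "2x faster transfer learning" claim describes.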