ThinkingViT: Matryoshka Thinking Vision Transformer for Elastic Inference

Ali Hojjat, Janek Haberer, Soren Pirk, Olaf Landsiedel

TL;DR

ThinkingViT tackles the inefficiency of fixed-budget Vision Transformers by introducing elastic, input-aware computation through nested subnetworks and progressive thinking. It activates a small set of attention heads first and only expands computation when confidence is insufficient, using Token Recycling to fuse embeddings across rounds. Training jointly across all subnetworks with a weighted loss and entropy-based early stopping enables accurate predictions with fewer GMACs and flexible throughput, outperforming fixed-budget and other nested baselines. The approach preserves the backbone while enabling transfer to downstream tasks like semantic segmentation and to hierarchical architectures such as Swin, offering practical deployment benefits across diverse hardware budgets.
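
The entropy-based stopping rule mentioned above can be illustrated with a minimal sketch. It assumes confidence is measured as the Shannon entropy of the softmax output, with low entropy meaning a confident prediction. The function names and the `max_entropy` threshold are hypothetical and are not taken from the ThinkingViT code base.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of class logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; zero-probability classes contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def should_stop(logits, max_entropy):
    """Hypothetical rule: stop thinking once prediction entropy is low enough.

    `max_entropy` is an assumed tunable threshold, not a value from the paper.
    """
    return entropy(softmax(logits)) <= max_entropy
```

Raising `max_entropy` lets more inputs exit after the first round, which lowers average compute at some cost in accuracy; lowering it sends more inputs to a further round.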

Abstract

ViTs deliver SOTA performance, yet their fixed computational budget prevents scalable deployment across heterogeneous hardware. Recent Matryoshka-style Transformer architectures mitigate this by embedding nested subnetworks within a single model to enable scalable inference. However, these models allocate the same amount of compute to all inputs, regardless of their complexity, which leads to inefficiencies. To address this, we introduce ThinkingViT, a nested ViT architecture that employs progressive thinking stages to dynamically adjust inference computation based on input difficulty. ThinkingViT first activates a small subset of the most important attention heads to produce an initial prediction. If the prediction confidence exceeds a predefined threshold, inference terminates early. Otherwise, within the same backbone, it activates a larger subset of attention heads and conducts a new forward pass. This process continues iteratively until the model reaches the predefined confidence level or exhausts its maximum capacity. To boost the performance of subsequent rounds, we introduce a Token Recycling approach that fuses the input embeddings with the embeddings from the previous stage. Experiments show that ThinkingViT surpasses nested baselines by up to 2.0 percentage points (p.p.) in accuracy at the same throughput and by up to 2.9 p.p. at equal GMACs on ImageNet-1K. We show that the backbone-preserving design of ThinkingViT allows it to serve as a plug-in upgrade for ViTs in downstream tasks such as semantic segmentation. We also demonstrate that ThinkingViT transfers effectively to other architectures such as Swin. The source code is available at https://github.com/ds-kiel/ThinkingViT.
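
The control flow described in the abstract can be sketched as a short loop. This is an illustration under stated assumptions, not the authors' implementation: `run_stage`, `fuse`, and `is_confident` are hypothetical callables standing in for a forward pass with a given number of heads, Token Recycling, and the confidence check. The toy functions below exist only to make the sketch runnable.

```python
from typing import Callable, List, Sequence, Tuple

Tokens = List[float]
Logits = List[float]


def progressive_inference(
    embeddings: Tokens,
    run_stage: Callable[[Tokens, int], Tuple[Logits, Tokens]],
    fuse: Callable[[Tokens, Tokens], Tokens],
    is_confident: Callable[[Logits], bool],
    head_schedule: Sequence[int] = (3, 6),
) -> Tuple[Logits, int]:
    """Run stages with growing head counts until confident or out of capacity.

    Returns the final logits and the number of heads used by the last stage.
    """
    if not head_schedule:
        raise ValueError("head_schedule must contain at least one stage")
    tokens = list(embeddings)
    for stage_index, num_heads in enumerate(head_schedule):
        logits, stage_tokens = run_stage(tokens, num_heads)
        is_last_stage = stage_index == len(head_schedule) - 1
        if is_last_stage or is_confident(logits):
            return logits, num_heads
        # Not confident yet: recycle this stage's tokens into the next input.
        tokens = fuse(embeddings, stage_tokens)
    raise AssertionError("unreachable")


# Toy stand-ins (assumptions for illustration only, not a real ViT).
def toy_stage(tokens: Tokens, num_heads: int) -> Tuple[Logits, Tokens]:
    margin = sum(tokens) * num_heads  # more heads -> sharper logits
    return [margin, 0.0], [0.5 * t for t in tokens]


def toy_fuse(embeddings: Tokens, stage_tokens: Tokens) -> Tokens:
    return [x + z for x, z in zip(embeddings, stage_tokens)]


def toy_is_confident(logits: Logits) -> bool:
    return max(logits) - min(logits) >= 2.0


easy = progressive_inference([0.5, 0.5], toy_stage, toy_fuse, toy_is_confident)
hard = progressive_inference([0.1, 0.1], toy_stage, toy_fuse, toy_is_confident)
```

In this toy setup the "easy" input exits after the 3-head stage, while the "hard" input falls through to the 6-head stage. The default schedule `(3, 6)` mirrors the two-stage configuration reported for the main results.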

Paper Structure

This paper contains 29 sections, 6 equations, 16 figures, 7 tables, 1 algorithm.

Figures (16)

  • Figure 1: Comparison of ThinkingViT with MatFormer (NeurIPS'24) [devvrit2024matformer], HydraViT (NeurIPS'24) [haberer2024hydravit], SortedNet (NeurIPS'23-W) [valipour2023sortednet], and DynaBERT (NeurIPS'20) [hou2020dynabert], evaluated in terms of GMACs and throughput on an A100. All baselines have a dynamic width within a standard backbone and are trained following the training recipes in [touvron2021training]. ThinkingViT is trained with two progressive thinking stages using 3 and 6 heads, and consistently surpasses the baselines by up to 2.0 p.p. at the same throughput and by up to 2.9 p.p. at the same GMACs. See Appendix \ref{appendix_gmacs_imagenet_val_for_small} for a comparison with smaller baseline models.
  • Figure 2: Nested progressive inference with Token Recycling in ThinkingViT. After embedding the input, ThinkingViT first activates a subset of the model, including the first attention heads (e.g., 50%), to produce an initial prediction. Due to the training procedure, these heads capture the most important features. If the certainty exceeds a threshold (easy inputs), inference terminates early to save computation. Otherwise, the resulting tokens are fused back into the input via a learnable scaling factor $\alpha$, which controls how much prior knowledge is recycled. The model then thinks more by reprocessing the fused tokens using a larger subset of the attention heads (e.g., 100%) for a refined prediction. ThinkingViT enables elastic inference across different hardware budgets by adjusting the confidence threshold. The number of thinking stages and the proportion of attention heads activated at each stage can be flexibly configured based on efficiency and accuracy trade-offs; see Section \ref{sec:trade_offs}.
  • Figure 3: Accuracy vs. GMACs for ThinkingViT variants on ImageNet-1K. 3H $\rightarrow$ 6H achieves the best trade-off, while 2H $\rightarrow$ 3H $\rightarrow$ 6H covers the widest range with only a slight accuracy drop. See Appendix \ref{sec:thinking_ablation} for details.
  • Figure 4: GMACs vs. Accuracy on ImageNet variants. ThinkingViT has superior performance compared to the baselines.
  • Figure 5: Visualization of images sorted by first-round entropy. ThinkingViT confidently classifies simple, clear images in one round, while complex cases with occlusion or clutter show higher entropy and trigger a second round.
  • ...and 11 more figures
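
The Token Recycling step in the Figure 2 caption can be illustrated with a small sketch. The caption states only that the previous-stage tokens are fused into the input via a learnable scaling factor $\alpha$; the additive form used here, the function name `recycle_tokens`, and the requirement that both token sets already share the same shape are assumptions made for illustration. In a real model $\alpha$ would be a trained parameter rather than a plain float.

```python
from typing import List

TokenMatrix = List[List[float]]  # one row per token, one column per feature


def recycle_tokens(
    input_embeddings: TokenMatrix,
    previous_tokens: TokenMatrix,
    alpha: float,
) -> TokenMatrix:
    """Assumed additive fusion: fused = input_embeddings + alpha * previous_tokens.

    Both inputs must have the same shape; how a narrower first stage is
    brought to the width of the next stage is outside this sketch.
    """
    if len(input_embeddings) != len(previous_tokens):
        raise ValueError("token counts differ")
    fused = []
    for x_row, z_row in zip(input_embeddings, previous_tokens):
        if len(x_row) != len(z_row):
            raise ValueError("feature widths differ")
        fused.append([x + alpha * z for x, z in zip(x_row, z_row)])
    return fused
```

Under this assumed form, `alpha = 0` makes the next round ignore the previous one entirely, and larger values recycle more of the earlier stage's output, which matches the caption's description of $\alpha$ as controlling how much prior knowledge is reused.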