
Token-UNet: A New Case for Transformers Integration in Efficient and Interpretable 3D UNets for Brain Imaging Segmentation

Louis Fabrice Tshimanga, Andrea Zanola, Federico Del Pup, Manfredo Atzori

TL;DR

This work reconsiders the role of convolution and attention by introducing Token-UNets, a family of 3D segmentation models that can operate in constrained computational environments and time frames. Token-UNets keep the convolutional encoder of UNet-like models and apply TokenLearner to 3D feature maps.

Abstract

We present Token-UNet, which adopts the TokenLearner and TokenFuser modules to encase Transformers into UNets. While Transformers have enabled global interactions among input elements in medical imaging, computational challenges still hinder their deployment on common hardware. Models like (Swin)UNETR adapt the UNet architecture by incorporating (Swin)Transformer encoders, which process tokens that each represent a small subvolume ($8^3$ voxels) of the input. The cost of Transformer attention scales quadratically with the number of tokens, which in turn grows cubically with the resolution of a 3D input. This work reconsiders the role of convolution and attention, introducing Token-UNets, a family of 3D segmentation models that can operate in constrained computational environments and time frames. To reduce computational demands, our approach keeps the convolutional encoder of UNet-like models and applies TokenLearner to 3D feature maps. This module pools a preset number of tokens from local and global structures. Our results show that this tokenization effectively encodes task-relevant information, yielding naturally interpretable attention maps. The memory footprint, inference time, and parameter count of our heaviest model are reduced to 33\%, 10\%, and 35\% of the SwinUNETR values, respectively, with better average performance (Dice score of $86.75\% \pm 0.19\%$ for SwinUNETR vs. $87.21\% \pm 0.35\%$ for ours). This work opens the way to more efficient training in contexts with limited computational resources, such as 3D medical imaging. Easing model optimization, fine-tuning, and transfer learning on limited hardware can accelerate and diversify the development of new approaches, to the benefit of the research community.
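To make the scaling argument concrete, the snippet below counts tokens and attention-matrix entries for plain global self-attention over non-overlapping $8^3$-voxel patches, following the patch size quoted in the abstract. The cubic input sizes are hypothetical, and SwinUNETR restricts attention to local windows, so its real cost is lower than this global bound. The fixed count of N = 8 learned tokens is taken from the Figure 3 caption.

```python
def patch_tokens(side: int, patch: int = 8) -> int:
    """Tokens produced by splitting a side^3 volume into non-overlapping patch^3 subvolumes."""
    if side % patch:
        raise ValueError("side must be a multiple of the patch size")
    return (side // patch) ** 3


def attention_entries(num_tokens: int) -> int:
    """Entries in one head's token-by-token matrix for global self-attention."""
    return num_tokens ** 2


for side in (64, 128, 256):  # hypothetical cubic input sizes
    n = patch_tokens(side)
    print(f"{side}^3 voxels: {n} tokens, {attention_entries(n):,} attention entries")

# A TokenLearner that pools a fixed number of tokens (N = 8 in Figure 3)
# keeps the attention matrix at 8 x 8 regardless of input resolution.
print(f"N = 8 learned tokens: {attention_entries(8)} attention entries")
```

Doubling the side length multiplies the patch-token count by 8 and the attention matrix by 64, whereas a preset number of pooled tokens keeps the attention matrix the same size at any resolution (the pooling step itself still visits every voxel).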

Paper Structure

This paper contains 17 sections, 6 equations, 7 figures, and 2 tables.

Figures (7)

  • Figure 1: Architecture of Token-UNet including the encased Transformer.
  • Figure 2: General form of our ResBlocks, with options to upsample or downsample feature maps along the residual paths (with strided or transposed convolutions) and the internal skip-connections (with average pooling or linear upsampling).
  • Figure 3: The last encoded cube is fed to TokenLearner and transformed into N=8 different pooled vectors that encode semantic information from all pertinent locations. These vectors then serve as tokens in the Transformer. TokenFuser then recreates N=8 spatial masks for the tokens; the tokens are linearly mixed, broadcast over the masks, and finally added to the last encoded cube, which is then decoded by the ascending UNet path.
  • Figure 4: Inference time vs. GPU memory usage at inference.
  • Figure 5: Loss history by fold for every architecture.
  • ...and 2 more figures
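In the ResBlocks of Figure 2, the residual path and the internal skip-connection must produce feature maps of the same shape before they are summed. The sketch below checks this for the resampling step only, assuming hypothetical settings: a stride-2 convolution with kernel 3 and padding 1 paired with 2×2×2 average pooling for downsampling, and a stride-2 transposed convolution with kernel 2 paired with 2× linear upsampling for upsampling. These hyperparameters are illustrative and not taken from the paper.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one axis of a (strided) convolution."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one axis of a transposed convolution (no output padding)."""
    return (n - 1) * stride - 2 * padding + kernel


def resblock_shape(shape: tuple, mode: str) -> tuple:
    """Return the output shape if the residual path and the skip can be summed.

    mode: "down" (strided conv + average pooling), "up" (transposed conv +
    linear upsampling), or "same" (no resampling). Hyperparameters are hypothetical.
    """
    if mode == "down":
        main = tuple(conv_out(n, kernel=3, stride=2, padding=1) for n in shape)
        skip = tuple(n // 2 for n in shape)  # average pooling, kernel 2, stride 2
    elif mode == "up":
        main = tuple(conv_transpose_out(n, kernel=2, stride=2, padding=0) for n in shape)
        skip = tuple(2 * n for n in shape)  # linear upsampling, scale factor 2
    elif mode == "same":
        main = tuple(conv_out(n, kernel=3, stride=1, padding=1) for n in shape)
        skip = tuple(shape)
    else:
        raise ValueError(f"unknown mode: {mode}")
    if main != skip:
        raise ValueError(f"residual path {main} and skip {skip} cannot be summed")
    return main


shape = (128, 128, 128)  # hypothetical input crop
for mode in ("down", "down", "same", "up", "up"):
    shape = resblock_shape(shape, mode)
    print(mode, shape)
```

With these settings, odd side lengths break the downsampling pairing (for example, 5 maps to 3 through the strided convolution but to 2 through pooling), which is one reason 3D UNets are usually fed crops whose sides are divisible by the total downsampling factor.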
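The pooling and fusion described in the Figure 3 caption can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: each token's spatial attention map is assumed to be a linear projection of the voxel features followed by a softmax over voxels (one common TokenLearner variant), the TokenFuser masks are assumed to be sigmoids of another linear projection of the input, and the cube size, channel count, and random weights are hypothetical. The encased Transformer is omitted.

```python
import math
import random

Vector = list[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(logits: Vector) -> Vector:
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def random_matrix(rows: int, cols: int, scale: float) -> list[Vector]:
    return [[random.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def token_learner(voxels, w_attn):
    """Pool one token per row of w_attn from a flattened 3D feature map.

    voxels: V feature vectors of length C (the D*H*W grid, flattened).
    w_attn: N weight vectors of length C (hypothetical learned parameters).
    Returns N tokens of length C and the N spatial attention maps over the V voxels.
    """
    maps = [softmax([dot(w, x) for x in voxels]) for w in w_attn]  # each map sums to 1
    channels = len(voxels[0])
    tokens = [[sum(a * x[c] for a, x in zip(m, voxels)) for c in range(channels)]
              for m in maps]
    return tokens, maps


def token_fuser(voxels, tokens, mix, w_mask):
    """Broadcast mixed tokens back onto the voxel grid and add them to the input.

    mix:    N x N matrix that linearly mixes the tokens.
    w_mask: N weight vectors of length C; a sigmoid turns them into per-voxel masks.
    """
    n, channels = len(tokens), len(tokens[0])
    mixed = [[sum(mix[j][i] * tokens[i][c] for i in range(n)) for c in range(channels)]
             for j in range(n)]
    fused = []
    for x in voxels:
        masks = [sigmoid(dot(w, x)) for w in w_mask]  # one mask value per token at this voxel
        fused.append([x[c] + sum(b * y[c] for b, y in zip(masks, mixed))
                      for c in range(channels)])
    return fused


random.seed(0)
D = H = W = 4  # hypothetical size of the last encoded cube
C, N = 16, 8   # hypothetical channel count; N = 8 tokens as in Figure 3
voxels = random_matrix(D * H * W, C, 1.0)
tokens, maps = token_learner(voxels, random_matrix(N, C, 0.1))
# The encased Transformer would update `tokens` here; it is omitted in this sketch.
fused = token_fuser(voxels, tokens, random_matrix(N, N, 0.1), random_matrix(N, C, 0.1))
print(len(tokens), len(tokens[0]), len(fused), len(fused[0]))  # 8 16 64 16
```

Each entry of `maps` is one spatial weight map over the voxels; reshaped to D×H×W, maps of this kind are what one would visualize when inspecting which regions each token attends to. The fused output keeps the shape of the input cube, so it can be passed directly to the decoder.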