Compute-Efficient Medical Image Classification with Softmax-Free Transformers and Sequence Normalization

Firas Khader, Omar S. M. El Nahhas, Tianyu Han, Gustav Müller-Franzes, Sven Nebelung, Jakob Nikolas Kather, Daniel Truhn

TL;DR

This work tackles the quadratic cost of self-attention in Transformer models for high-resolution medical images. It introduces a softmax-free attention mechanism coupled with sequence normalization: the queries, keys, and values (Q, K, V) are normalized across the sequence dimension, followed by learnable affine transforms and a $1/N$ scaling. Removing the softmax operation allows the matrix multiplications to be reordered, so cost scales linearly with sequence length. Across five diverse medical imaging datasets (fundoscopic, dermoscopic, chest radiographs, breast MRI, and whole-slide images), the method attains performance close to traditional ViT models and consistently outperforms the SimA baseline, especially at higher resolutions. The results, along with finetuning experiments and scaling analyses, demonstrate practical efficiency gains suitable for edge devices with little loss in accuracy, highlighting a broadly applicable approach to scalable medical image classification.
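A minimal NumPy sketch of the mechanism described above, for a single attention head. It assumes that sequence normalization computes a per-channel mean and variance over the N tokens (the instance-normalization analogy in Figure 1), that Q, K, and V each get their own learnable scale `gamma` and shift `beta`, and that the $1/N$ factor is applied after the reordered product $Q(K^\top V)$. The function names, the `eps` value, and the random weights are illustrative choices, not the authors' implementation; multi-head splitting, the output projection, and the rest of the Transformer block are omitted.

```python
import numpy as np


def sequence_norm(x, gamma, beta, eps=1e-5):
    """Normalize each feature channel across the sequence (token) axis.

    x has shape (N, d): N tokens with d features each. As in instance
    normalization, mean and variance are computed over the N tokens
    separately for every channel, then a learnable affine map is applied.
    """
    mean = x.mean(axis=0, keepdims=True)  # (1, d)
    var = x.var(axis=0, keepdims=True)    # (1, d)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta           # gamma, beta: shape (d,)


def softmax_free_attention(x, w_q, w_k, w_v, params):
    """Single-head softmax-free attention with sequence normalization.

    K^T V (d x d) is formed first, so no N x N matrix is materialized and
    the cost is O(N * d^2) instead of O(N^2 * d).
    """
    n = x.shape[0]
    q = sequence_norm(x @ w_q, *params["q"])
    k = sequence_norm(x @ w_k, *params["k"])
    v = sequence_norm(x @ w_v, *params["v"])
    kv = k.T @ v           # (d, d)
    return (q @ kv) / n    # (N, d), including the 1/N scaling


# Usage with random inputs (illustrative sizes only).
rng = np.random.default_rng(0)
n, d = 1024, 64
x = rng.standard_normal((n, d))
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
# gamma = 1, beta = 0 at initialization for each of Q, K, V.
params = {name: (np.ones(d), np.zeros(d)) for name in ("q", "k", "v")}
out = softmax_free_attention(x, w_q, w_k, w_v, params)
print(out.shape)  # (1024, 64)
```

Because `kv` is only $d \times d$, the memory needed for the attention step in this sketch does not grow quadratically with the number of tokens.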

Abstract

The Transformer model has been pivotal in advancing fields such as natural language processing, speech recognition, and computer vision. However, a critical limitation of this model is its quadratic computational and memory complexity relative to the sequence length, which constrains its application to longer sequences. This is especially crucial in medical imaging, where high-resolution images can reach gigapixel scale. Efforts to address this issue have predominantly focused on complex techniques, such as decomposing the softmax operation integral to the Transformer's architecture. This paper introduces a remarkably simple and effective method that circumvents this quadratic complexity by eliminating the softmax function from the attention mechanism and applying a sequence normalization technique to the key, query, and value tokens. Coupled with a reordering of the matrix multiplications, this approach reduces the memory and compute complexity to linear in the sequence length. We evaluate this approach across various medical imaging datasets comprising fundoscopic, dermoscopic, radiologic, and histologic imaging data. Our findings show that these models achieve performance comparable to traditional Transformer models while efficiently handling longer sequences.
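The reordering mentioned in the abstract works because, once the row-wise softmax is removed, attention is a plain chain of matrix products, and matrix multiplication is associative. The sketch below (NumPy, random unnormalized inputs, arbitrarily chosen sizes) checks numerically that both orderings agree and counts multiply-accumulates (MACs) for each; the `matmul_macs` helper is hypothetical and ignores memory traffic.

```python
import numpy as np


def matmul_macs(rows, inner, cols):
    """Multiply-accumulate count for a (rows x inner) @ (inner x cols) product."""
    return rows * inner * cols


rng = np.random.default_rng(0)
n_tokens, d = 512, 32
q, k, v = (rng.standard_normal((n_tokens, d)) for _ in range(3))

# Without softmax, matrix multiplication is associative: (Q K^T) V == Q (K^T V).
left = (q @ k.T) @ v    # builds an N x N intermediate
right = q @ (k.T @ v)   # builds only a d x d intermediate
print(np.allclose(left, right))  # True

# MACs for each ordering: two matrix products in both cases.
quadratic = matmul_macs(n_tokens, d, n_tokens) + matmul_macs(n_tokens, n_tokens, d)
linear = matmul_macs(d, n_tokens, d) + matmul_macs(n_tokens, d, d)
print(quadratic, linear)  # 16777216 1048576
```

The ratio of the two counts is $N/d$ (here 16), so the saving grows with sequence length. With softmax, $\mathrm{softmax}(QK^\top)$ must be formed in full before it multiplies $V$, which is why this reordering is unavailable in standard attention.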

Paper Structure (16 sections, 3 equations, 2 figures, 2 tables)

Figures (2)

  • Figure 1: Illustration of the attention computation mechanism (left) and sequence normalization approach (right) used in our study. Note that sequence normalization resembles an instance normalization step across the sequence dimension.
  • Figure 2: Scaling behavior of the three models with respect to the input resolution. Both our approach and the SimA model scale much more efficiently than the ViT in terms of epoch duration (left) and GPU VRAM consumption (right) for the 15,000 images contained in the VinDr-CXR dataset. Note that a batch size of 1 was used for all runs. Evaluations of the ViT at an image resolution of $2048\times2048$ were omitted because the model did not fit into the 48 GB of GPU memory.
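To put the resolutions in Figure 2 in perspective, the sketch below estimates the token count and the size of the attention intermediate for one head in one layer. It assumes a ViT-style patch size of 16, a head dimension of 64, and float32 storage; none of these values is taken from the paper, the listed resolutions are illustrative, and real memory use also includes activations, weights, and optimizer state.

```python
def attention_intermediate_bytes(resolution, patch_size=16, head_dim=64, bytes_per_value=4):
    """Return (tokens, bytes of N x N attention matrix, bytes of d x d K^T V matrix).

    All defaults are illustrative assumptions, not values from the paper.
    """
    n_tokens = (resolution // patch_size) ** 2
    quadratic = n_tokens * n_tokens * bytes_per_value  # softmax attention: N x N
    linear = head_dim * head_dim * bytes_per_value     # reordered product: d x d
    return n_tokens, quadratic, linear


for res in (256, 512, 1024, 2048):
    n, quad, lin = attention_intermediate_bytes(res)
    print(f"{res}x{res}: N={n}, NxN={quad / 2**20:.2f} MiB, dxd={lin / 2**10:.0f} KiB")
# 256x256: N=256, NxN=0.25 MiB, dxd=16 KiB
# 512x512: N=1024, NxN=4.00 MiB, dxd=16 KiB
# 1024x1024: N=4096, NxN=64.00 MiB, dxd=16 KiB
# 2048x2048: N=16384, NxN=1024.00 MiB, dxd=16 KiB
```

Under these assumptions, each doubling of the resolution multiplies the token count by 4 and the $N \times N$ matrix by 16, while the $d \times d$ matrix of the reordered product stays constant.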