Channel Vision Transformers: An Image Is Worth 1 x 16 x 16 Words

Yujia Bao, Srinivasan Sivanandan, Theofanis Karaletsos

TL;DR

ChannelViT extends vision transformers to multi-channel imaging by tokenizing each channel separately and learning channel embeddings, enabling reasoning across both channels and spatial locations. It also introduces Hierarchical Channel Sampling (HCS), which regularizes training by sampling varying channel subsets and improves robustness when channels are missing at test time. Across ImageNet, JUMP-CP, and So2Sat, ChannelViT consistently outperforms ViT; HCS further improves generalization and data efficiency, and ChannelViT's channel-wise attention is interpretable. Because the low-level patch projection is shared across channels and the learned channel embeddings are interpretable, the approach is practical for diverse multi-channel imaging tasks with sparse sensors and other real-world constraints.
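The per-channel tokenization can be made concrete with a minimal NumPy sketch. It assumes 16 x 16 single-channel patches (the "1 x 16 x 16 words" of the title), one projection matrix shared by all channels, one learnable embedding vector per channel, and positional embeddings shared across channels, as in Figure 1. The function name `channelvit_tokens`, its arguments, and the omission of a [CLS] token are illustrative choices, not the authors' implementation. An image with C channels and N patch positions yields C·N tokens, so passing a subset of channels simply produces a shorter sequence.

```python
import numpy as np

def channelvit_tokens(image, proj, chn_emb, pos_emb, channel_ids=None, patch=16):
    """Hypothetical sketch of ChannelViT-style patch tokenization (no [CLS] token).

    image:       (C, H, W) array holding only the channels that are present
    proj:        (patch*patch, D) projection shared by every channel (W in Figure 1)
    chn_emb:     (C_total, D) learnable channel embeddings, one row per channel
    pos_emb:     (N, D) positional embeddings shared across channels
    channel_ids: indices into chn_emb for the channels in `image` (default: all)
    returns:     (C*N, D) tokens in channel-major order (token c*N + n)
    """
    C, H, W = image.shape
    if channel_ids is None:
        channel_ids = np.arange(C)
    channel_ids = np.asarray(channel_ids)
    if channel_ids.shape != (C,):
        raise ValueError("need exactly one channel id per input channel")
    gh, gw = H // patch, W // patch
    # Cut every channel into non-overlapping single-channel patches: (C, gh, gw, patch, patch)
    tiles = image[:, :gh * patch, :gw * patch].reshape(C, gh, patch, gw, patch)
    tiles = tiles.transpose(0, 1, 3, 2, 4).reshape(C, gh * gw, patch * patch)
    # The same linear projection is applied to every channel's patches: (C, N, D)
    tokens = tiles @ proj
    # Add the shared positional embeddings and the embedding of each token's channel
    tokens = tokens + pos_emb[None, :, :] + chn_emb[channel_ids][:, None, :]
    return tokens.reshape(C * gh * gw, -1)

# Example: 8 channels of 224 x 224 pixels, embedding width 384 (illustrative values)
rng = np.random.default_rng(0)
image = rng.standard_normal((8, 224, 224))
proj = rng.standard_normal((16 * 16, 384))
chn_emb = rng.standard_normal((8, 384))
pos_emb = rng.standard_normal((14 * 14, 384))
print(channelvit_tokens(image, proj, chn_emb, pos_emb).shape)  # (1568, 384)
print(channelvit_tokens(image[[0, 5, 7]], proj, chn_emb, pos_emb, [0, 5, 7]).shape)  # (588, 384)
```

For the illustrative 8-channel, 224 x 224 input this gives 8 x 196 = 1568 tokens instead of the 196 a standard ViT/16 would produce, so self-attention, whose cost is quadratic in sequence length, becomes noticeably more expensive as channels are added.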

Abstract

Vision Transformer (ViT) has emerged as a powerful architecture in modern computer vision. However, applying it in certain imaging fields, such as microscopy and satellite imaging, presents unique challenges. In these domains, images often contain multiple channels, each carrying semantically distinct and independent information. Furthermore, the model must be robust to sparse input channels, since not all channels may be available during training or testing. In this paper, we propose a modification to the ViT architecture that enhances reasoning across the input channels, and we introduce Hierarchical Channel Sampling (HCS) as an additional regularization technique to ensure robustness when only a subset of channels is present at test time. Our proposed model, ChannelViT, constructs patch tokens independently from each input channel and adds a learnable channel embedding to the patch tokens, similar to positional embeddings. We evaluate ChannelViT on ImageNet, JUMP-CP (microscopy cell imaging), and So2Sat (satellite imaging). Our results show that ChannelViT outperforms ViT on classification tasks and generalizes well even when only a subset of input channels is used during testing. Across our experiments, HCS proves to be a powerful regularizer independent of the architecture employed, making it a straightforward technique for robust ViT training. Lastly, we find that ChannelViT generalizes effectively even when access to all channels is limited during training, highlighting its potential for multi-channel imaging under real-world conditions with sparse sensors. Our code is available at https://github.com/insitro/ChannelViT.
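HCS can be sketched in plain Python. The sketch assumes a two-level scheme (first draw how many channels to keep, uniformly from 1 to C; then draw which channels, uniformly among subsets of that size), which is the reading of "hierarchical" assumed here rather than code from the authors' repository, and compares it with independent per-channel input dropout at rate p, the baseline in Figure 3. All function names are hypothetical. The comparison is over subset sizes: under this assumption HCS gives every size probability 1/C, whereas independent dropout yields a binomial distribution that concentrates on intermediate sizes.

```python
import random
from math import comb

def hcs_sample(num_channels, rng=random):
    """Hypothetical HCS sampler: returns a sorted list of channel indices to keep.

    Assumed two-level scheme:
      1. draw the subset size m uniformly from {1, ..., C};
      2. draw m distinct channels uniformly at random.
    """
    m = rng.randint(1, num_channels)                   # level 1: how many channels
    return sorted(rng.sample(range(num_channels), m))  # level 2: which channels

def hcs_size_prob(num_channels, m):
    """P(subset size == m) under the assumed HCS scheme, for 1 <= m <= C."""
    return 1 / num_channels

def hcs_subset_prob(num_channels, m):
    """P(one specific subset of size m) under the assumed HCS scheme."""
    return 1 / (num_channels * comb(num_channels, m))

def dropout_size_prob(num_channels, m, p):
    """P(exactly m channels survive) when each channel is dropped independently
    with rate p. (m = 0, every channel dropped, has probability p**C and would
    need to be resampled in practice.)"""
    return comb(num_channels, m) * (1 - p) ** m * p ** (num_channels - m)

# Example with the 8 JUMP-CP channels and a dropout rate of 0.5
rng = random.Random(0)
print(hcs_sample(8, rng))  # one random channel subset
for m in range(1, 9):
    print(m, hcs_size_prob(8, m), dropout_size_prob(8, m, 0.5))
```

With C = 8 and p = 0.5, dropout keeps exactly 4 channels with probability 70/256 ≈ 0.27 and exactly 1 channel with probability 8/256 ≈ 0.03, while the assumed HCS scheme gives each size 1/8 = 0.125. This skew is consistent with the Figure 3 observation that ViTs trained with input channel dropout favor the most frequently sampled combinations. For ChannelViT, one way to apply a sampled subset is to drop the tokens of the unsampled channels; for a standard ViT, the unsampled channels could be zeroed instead (both are assumptions of this sketch).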

Paper Structure (58 sections, 9 equations, 11 figures, 17 tables)

Figures (11)

  • Figure 1: Illustration of Channel Vision Transformer (ChannelViT). The input for ChannelViT is a cell image from JUMP-CP, which comprises five fluorescence channels (colored differently) and three brightfield channels (colored in B&W). ChannelViT generates patch tokens for each individual channel, utilizing a learnable channel embedding $\mathrm{chn}$ to preserve channel-specific information. The positional embeddings $\mathrm{pos}$ and the linear projection $W$ are shared across all channels.
  • Figure 2: Correlation patterns among image channels (left) and the learned channel embeddings (right) for ImageNet, JUMP-CP, and So2Sat. ImageNet displays a strong correlation among the three RGB input channels, while JUMP-CP and So2Sat show minimal correlation between different signal sources (Fluorescence vs. Brightfield, Sentinel 1 vs. Sentinel 2).
  • Figure 3: HCS vs. input channel dropout on JUMP-CP (trained on all 8 channels). On the left, we present the accuracy of ViT-S/16 and ChannelViT-S/16 under varying input channel dropout rates and under HCS. Accuracy is evaluated on every channel combination, and the mean accuracy is reported for each group of combinations with the same number of channels (horizontal axis). On the right, we illustrate the probability distribution of the sampled channel combinations during training. We observe that (1) ViTs trained with input channel dropout tend to favor the channel combinations that are sampled most often; (2) ChannelViT with input channel dropout outperforms ViT with input channel dropout; and (3) HCS surpasses input channel dropout in channel robustness.
  • Figure 4: Left: Class-specific relevance attribution of ChannelViT-S/8 for each cell label (perturbed gene) on JUMP-CP. For each perturbed gene (y-axis) and each channel (x-axis), we calculate the maximum attention score, averaged over 100 cells from that specific cell label. This reveals that ChannelViT focuses on different input channels depending on the perturbed gene. Right: A visualization of the relevance heatmaps for both ViT-S/8 (8-channel view) and ChannelViT-S/8 (single-channel view). Both models are trained on JUMP-CP using HCS across all 8 channels. ChannelViT offers interpretability by highlighting the contributions made by each individual channel.
  • Figure 4: Test accuracy of 17-way local climate zone classification on So2Sat. We consider two official splits: random split and city split. Both ViT and ChannelViT are trained on all channels with hierarchical channel sampling. We evaluate their performance on 18 channels (Sentinel 1 & 2) as well as partial channels (Sentinel 1).
  • ...and 6 more figures
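The per-channel summary behind the class-specific relevance panel of Figure 4 (the maximum attention score in each channel, averaged over the cells of one label) reduces to a reshape, a max, and a mean. This NumPy sketch assumes the per-token relevance scores have already been computed for each cell (how they are obtained is outside its scope) and that tokens are ordered channel by channel, so token c·N + n is patch n of channel c. The helper name `channel_relevance_by_label` is hypothetical.

```python
import numpy as np

def channel_relevance_by_label(scores, labels, num_channels):
    """Hypothetical helper: mean over cells of the per-channel maximum score.

    scores:  (num_cells, C*N) relevance/attention score per patch token,
             channel-major order (token c*N + n is patch n of channel c)
    labels:  (num_cells,) integer label of each cell (e.g. perturbed gene)
    returns: {label: (C,) array}
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    num_cells, num_tokens = scores.shape
    if num_tokens % num_channels:
        raise ValueError("token count must be a multiple of num_channels")
    per_channel = scores.reshape(num_cells, num_channels, num_tokens // num_channels)
    chan_max = per_channel.max(axis=2)  # (num_cells, C): strongest patch per channel
    # Average the per-channel maxima over all cells that share a label
    return {int(lab): chan_max[labels == lab].mean(axis=0) for lab in np.unique(labels)}

# Toy example: 4 cells, 2 channels, 3 patches per channel
scores = np.zeros((4, 6))
scores[0, 1] = 1.0  # label 0, channel 0
scores[1, 4] = 2.0  # label 0, channel 1
scores[2, 5] = 3.0  # label 1, channel 1
scores[3, 0] = 1.0  # label 1, channel 0
result = channel_relevance_by_label(scores, [0, 0, 1, 1], num_channels=2)
# result[0] -> [0.5, 1.0], result[1] -> [0.5, 1.5]
```

Replacing `max` with `mean` in the helper would give an area-averaged variant of the same summary.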