ChA-MAEViT: Unifying Channel-Aware Masked Autoencoders and Multi-Channel Vision Transformers for Improved Cross-Channel Learning

Chau Pham, Juan C. Caicedo, Bryan A. Plummer

TL;DR

Multi-Channel Imaging (MCI) presents channels with complementary information, challenging MAEs that assume cross-channel redundancy. ChA-MAEViT introduces Dynamic Channel-Patch Masking, Memory Tokens, a Channel-Aware Decoder, and Hybrid Token Fusion to force cross-channel reconstruction and preserve cross-channel information. Empirical results on CHAMMI, JUMP-CP, and So2Sat show 3.0-21.5% improvements over state-of-the-art MCI-ViTs, with strong robustness to missing channels. The approach offers scalable cross-channel learning for diverse sensors and modalities, with potential extensions to volumetric medical imaging.

Abstract

Prior work using Masked Autoencoders (MAEs) typically relies on random patch masking, based on the assumption that images have significant redundancies across different channels, which allows masked content to be reconstructed from cross-channel correlations. However, this assumption does not hold in Multi-Channel Imaging (MCI), where channels may provide complementary information with minimal feature overlap. Thus, these MAEs primarily learn local structures within individual channels from patch reconstruction, failing to fully leverage cross-channel interactions and limiting their effectiveness in MCI. In this paper, we present ChA-MAEViT, an MAE-based method that enhances feature learning across MCI channels via four key strategies: (1) dynamic channel-patch masking, which compels the model to reconstruct missing channels in addition to masked patches, thereby strengthening cross-channel dependencies and improving robustness to varying channel configurations; (2) memory tokens, which serve as long-term memory aids that promote information sharing across channels, addressing the challenge of reconstructing structurally diverse channels; (3) a hybrid token fusion module, which merges fine-grained patch tokens with a global class token to capture richer representations; and (4) a Channel-Aware Decoder, a lightweight decoder that uses channel tokens to reconstruct image patches effectively. Experiments on satellite and microscopy datasets (CHAMMI, JUMP-CP, and So2Sat) show that ChA-MAEViT significantly outperforms state-of-the-art MCI-ViTs by 3.0-21.5%, highlighting the importance of cross-channel interactions in MCI. Our code is publicly available at https://github.com/chaudatascience/cha_mae_vit.
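
The dynamic channel-patch masking strategy can be illustrated with a minimal, stdlib-only Python sketch. It is not the authors' implementation: the function name `dynamic_channel_patch_mask`, the uniform sampling of how many whole channels to drop, and the fixed `patch_ratio` applied to the remaining channels are assumptions made purely for illustration. The returned grid marks, per channel and patch, which tokens would be hidden from the encoder and therefore have to be reconstructed.

```python
import random
from typing import List, Optional, Tuple


def dynamic_channel_patch_mask(
    num_channels: int,
    num_patches: int,
    patch_ratio: float = 0.75,
    rng: Optional[random.Random] = None,
) -> Tuple[List[List[bool]], List[int]]:
    """Hypothetical sketch of combined channel + patch masking for one image.

    Returns (mask, dropped): mask[c][p] is True if patch p of channel c is
    hidden from the encoder; dropped lists the fully masked channels.
    Assumption: the number of dropped channels is drawn uniformly on every call
    (the "dynamic" part); the paper's exact sampling scheme may differ.
    """
    if num_channels < 1 or num_patches < 1:
        raise ValueError("need at least one channel and one patch")
    if not 0.0 <= patch_ratio < 1.0:
        raise ValueError("patch_ratio must be in [0, 1)")
    rng = rng if rng is not None else random.Random()

    # Channel masking: drop between 0 and C-1 whole channels, keeping at least one.
    k = rng.randint(0, num_channels - 1)
    dropped = sorted(rng.sample(range(num_channels), k))

    # Patch masking: hide a fixed fraction of patches in each remaining channel,
    # always leaving at least one patch of that channel visible.
    n_hidden = min(int(round(patch_ratio * num_patches)), num_patches - 1)

    mask = []
    for c in range(num_channels):
        if c in dropped:
            mask.append([True] * num_patches)
        else:
            hidden = set(rng.sample(range(num_patches), n_hidden))
            mask.append([p in hidden for p in range(num_patches)])
    return mask, dropped


# Example: a 5-channel image split into 16 patches per channel.
mask, dropped = dynamic_channel_patch_mask(5, 16, rng=random.Random(0))
visible = sum(not m for row in mask for m in row)
print("dropped channels:", dropped, "visible tokens:", visible)
```

Because a different subset of channels disappears on each call, the reconstruction target regularly includes whole channels that can only be inferred from the visible ones; plain random patch masking corresponds to the special case in which no channel is dropped.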

Paper Structure

This paper contains 27 sections, 3 equations, 17 figures, 10 tables, and 1 algorithm.

Figures (17)

  • Figure 1: Image patch interactions in MCI. (A) Prior work on MCI-MAEs (e.g., CA-MAE kraus2024masked) employs random patch masking to train the model to reconstruct the masked patches. In the attention map, where each row represents the average attention score of all patches in a channel towards the other channels and the [CLS] token (last column), each patch primarily attends to its own channel (the diagonal) and to the [CLS] token. This suggests that patch masking may not effectively promote cross-channel interactions in MCI. (B) In contrast, Dynamic Channel-Patch Masking (ours) encourages more interactions between patches across different channels by combining channel and patch masking. In addition, memory tokens serve as long-term memory that helps aggregate information across channels. The resulting attention pattern is distributed more uniformly across channels, indicating that each image patch can learn more meaningful interactions.
  • Figure 2: Our ChA-MAEViT approach enhances cross-channel learning via four key components: 1) Dynamic Channel-Patch Masking, which compels the model to reconstruct varying proportions of missing channels and patches, thus improving interactions across channels and robustness to the absence of some channels (Sec. \ref{sec:mask_strategy}). 2) Memory Tokens, which act as long-term memory to facilitate information sharing between channels (Sec. \ref{sec:memory_tokens}). 3) A Channel-Aware Decoder, which leverages channel tokens to reconstruct the masked patches and channels, enhancing performance while minimizing computational cost (Sec. \ref{sec:channel_aware_decoder}). 4) A Hybrid Token Fusion module, which combines fine-grained patch tokens with a global [CLS] token to improve feature representation (Sec. \ref{sec:hybrid_token_fusion}).
  • Figure 3: Impact of the number of memory tokens and of the reconstruction weight $\lambda_{\mathrm{recon}}$ (Eq. \ref{eq:finalloss}). (a) & (b) Using $4$-$8$ memory tokens improves performance; however, using more (e.g., $24$) may reduce their effectiveness. (c) & (d) $\lambda_{\mathrm{recon}}=0$ means training without the reconstruction loss, while $\lambda_{\mathrm{recon}}=1$ means using only the reconstruction loss. For $\lambda_{\mathrm{recon}}=1$ on So2Sat, we run linear probing. $\lambda_{\mathrm{recon}}=0.99$ works best for ChA-MAEViT on both datasets.
  • Figure 4: Attention between image patches and memory tokens of the encoder. Each channel group focuses on different memory tokens. (a) So2Sat: VH channels utilize memory token $8$, whereas Lee-filtered channels attend more to memory token $1$. (b) JUMP-CP: Brightfield channels focus on memory token $3$, while Fluorescence channels favor memory token $1$.
  • Figure 5: Different decoders used with ChA-MAEViT. Our Channel-Aware Decoder outperforms the best baseline decoder by $1.1$-$3.2\%$ on all three datasets.
  • ...and 12 more figures
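
The Figure 3 caption implies that $\lambda_{\mathrm{recon}}$ interpolates between a task-only objective ($\lambda_{\mathrm{recon}}=0$) and a reconstruction-only objective ($\lambda_{\mathrm{recon}}=1$). The exact form of Eq. \ref{eq:finalloss} is not reproduced on this page, so the Python sketch below assumes a convex combination $\lambda_{\mathrm{recon}}\,\mathcal{L}_{\mathrm{recon}} + (1-\lambda_{\mathrm{recon}})\,\mathcal{L}_{\mathrm{task}}$, which is one form consistent with both endpoints; the names `combined_loss` and `task_loss` are hypothetical.

```python
def combined_loss(recon_loss: float, task_loss: float, lambda_recon: float = 0.99) -> float:
    """Assumed convex weighting of reconstruction and task losses (hypothetical form).

    lambda_recon = 0 -> task loss only; lambda_recon = 1 -> reconstruction loss only.
    The default 0.99 mirrors the best value reported in the Figure 3 caption.
    """
    if not 0.0 <= lambda_recon <= 1.0:
        raise ValueError("lambda_recon must lie in [0, 1]")
    return lambda_recon * recon_loss + (1.0 - lambda_recon) * task_loss


# With the default weight, the reconstruction term dominates the total.
print(combined_loss(recon_loss=2.0, task_loss=4.0))
```

Under this assumed form, $\lambda_{\mathrm{recon}}=0.99$ keeps a small supervised signal alongside the reconstruction objective, whereas at exactly $1$ no task signal remains, which is presumably why linear probing is run for that setting on So2Sat.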