ChA-MAEViT: Unifying Channel-Aware Masked Autoencoders and Multi-Channel Vision Transformers for Improved Cross-Channel Learning
Chau Pham, Juan C. Caicedo, Bryan A. Plummer
TL;DR
Multi-Channel Imaging (MCI) produces channels that carry complementary information, which challenges MAEs that assume cross-channel redundancy. ChA-MAEViT introduces Dynamic Channel-Patch Masking, Memory Tokens, a Channel-Aware Decoder, and Hybrid Token Fusion to force cross-channel reconstruction and preserve cross-channel information. Empirical results on CHAMMI, JUMP-CP, and So2Sat show 3.0-21.5% improvements over state-of-the-art MCI-ViTs, with strong robustness to missing channels. The approach offers scalable cross-channel learning for diverse sensors and modalities, with potential extensions to volumetric medical imaging.
Abstract
Prior work using Masked Autoencoders (MAEs) typically relies on random patch masking, based on the assumption that images have significant redundancy across channels, so that masked content can be reconstructed from cross-channel correlations. However, this assumption does not hold in Multi-Channel Imaging (MCI), where channels may provide complementary information with minimal feature overlap. Consequently, these MAEs primarily learn local structures within individual channels from patch reconstruction, failing to fully leverage cross-channel interactions and limiting their effectiveness in MCI. In this paper, we present ChA-MAEViT, an MAE-based method that enhances feature learning across MCI channels via four key strategies: (1) dynamic channel-patch masking, which compels the model to reconstruct missing channels in addition to masked patches, thereby strengthening cross-channel dependencies and improving robustness to varying channel configurations; (2) memory tokens, which serve as long-term memory aids that promote information sharing across channels, addressing the challenge of reconstructing structurally diverse channels; (3) a hybrid token fusion module, which merges fine-grained patch tokens with a global class token to capture richer representations; and (4) a Channel-Aware Decoder, a lightweight decoder that uses channel tokens to reconstruct image patches effectively. Experiments on three satellite and microscopy datasets, CHAMMI, JUMP-CP, and So2Sat, show that ChA-MAEViT significantly outperforms state-of-the-art MCI-ViTs by 3.0-21.5%, highlighting the importance of cross-channel interactions in MCI. Our code is publicly available at https://github.com/chaudatascience/cha_mae_vit.
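To make strategy (1) more concrete, here is a minimal, hypothetical sketch of dynamic channel-patch masking in plain Python. It assumes each image is tokenized per channel into a grid of num_channels by num_patches tokens, one token per channel-patch pair. It also assumes that "dynamic" means a channel-masking ratio is redrawn for every sample. The function name `dynamic_channel_patch_mask`, the ratio range, and the 0.75 patch ratio are illustrative assumptions. They are not the paper's settings or its released code.

```python
import random
from typing import List, Optional, Tuple


def dynamic_channel_patch_mask(
    num_channels: int,
    num_patches: int,
    channel_ratio_range: Tuple[float, float] = (0.0, 0.5),
    patch_ratio: float = 0.75,
    rng: Optional[random.Random] = None,
) -> List[List[bool]]:
    """Hypothetical sketch: mask[c][p] is True if token (channel c, patch p) is hidden.

    Assumption: a random subset of whole channels is hidden, then MAE-style
    random patch masking is applied to the channels that are not hidden outright.
    """
    rng = rng or random.Random()

    # "Dynamic" (assumed): redraw the channel-masking ratio for every sample,
    # so the number of fully hidden channels varies across training steps.
    ratio = rng.uniform(*channel_ratio_range)
    # Never hide every channel outright; at least one must remain to reconstruct from.
    n_masked_channels = min(int(round(ratio * num_channels)), num_channels - 1)
    masked_channels = set(rng.sample(range(num_channels), n_masked_channels))

    n_masked_patches = int(patch_ratio * num_patches)
    mask: List[List[bool]] = []
    for c in range(num_channels):
        if c in masked_channels:
            # Channel masking: every patch of this channel must be inferred
            # from the other channels (cross-channel reconstruction).
            mask.append([True] * num_patches)
        else:
            # Patch masking: hide a random subset of this channel's patches.
            hidden = set(rng.sample(range(num_patches), n_masked_patches))
            mask.append([p in hidden for p in range(num_patches)])
    return mask


if __name__ == "__main__":
    mask = dynamic_channel_patch_mask(num_channels=5, num_patches=16, rng=random.Random(0))
    for c, row in enumerate(mask):
        print(c, "".join("x" if hidden else "." for hidden in row))
```

When a whole channel is hidden, the reconstruction target for that channel can only come from the visible channels. This is the cross-channel pressure that plain random patch masking lacks. Keeping at least one channel out of full channel masking is a design choice of this sketch.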
