Downstream Task Guided Masking Learning in Masked Autoencoders Using Multi-Level Optimization
Han Guo, Ramtin Hosseini, Ruiyi Zhang, Sai Ashish Somayajula, Ranak Roy Chowdhury, Rajesh K. Gupta, Pengtao Xie
TL;DR
This work addresses the uniform random patch masking used by MAE by introducing MLO-MAE, which learns a masking strategy guided by downstream-task feedback through an end-to-end, three-level optimization framework. The encoder, a learnable masking network, and a lightweight classifier are each optimized at their own level. The levels depend on one another and are coupled through hypergradients computed with implicit differentiation, so the learned masks concentrate on information-rich patches that matter for the downstream task. Experiments on CIFAR-10/100 and ImageNet, as well as transfer to fine-grained classification, semantic segmentation, and object detection, show substantial improvements over MAE-based baselines and robust transferability. The approach incurs a higher computational cost, which is partly offset by faster convergence (50 epochs). It also applies flexibly to settings such as continued pretraining and other forms of downstream feedback.
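The three-level structure described above can be illustrated with a deliberately tiny toy problem. This is a minimal sketch under stated assumptions, not the authors' implementation. The scalar parameters `w` (standing in for the encoder), `c` (the classifier), and `m` (the masking network), the quadratic losses, and all helper names are hypothetical. The hypergradient for `m` is approximated by unrolling one gradient step of the lower levels and taking a central finite difference. This is a simpler substitute for the implicit differentiation used in MLO-MAE.

```python
# Toy three-level optimization (hypothetical; NOT the MLO-MAE implementation).
# Level 1: "encoder" w minimizes a reconstruction-like loss that depends on the mask parameter m.
# Level 2: "classifier" c minimizes a training loss that depends on the encoder w.
# Level 3: "masking" parameter m minimizes the classifier's validation loss.


def recon_grad_w(w: float, m: float) -> float:
    """Gradient of the toy level-1 loss (w - m)**2 with respect to w."""
    return 2.0 * (w - m)


def train_grad_c(c: float, w: float) -> float:
    """Gradient of the toy level-2 loss (c - 2*w)**2 with respect to c."""
    return 2.0 * (c - 2.0 * w)


def val_loss(c: float) -> float:
    """Toy level-3 (validation) loss, minimized at c = 3."""
    return (c - 3.0) ** 2


def unrolled_val_loss(m: float, w: float, c: float, lr: float) -> float:
    """Validation loss after one gradient step on w, then one on c, viewed as a function of m."""
    w_next = w - lr * recon_grad_w(w, m)
    c_next = c - lr * train_grad_c(c, w_next)
    return val_loss(c_next)


def run_toy_mlo(steps: int = 1000, lr: float = 0.1, lr_m: float = 1.0, eps: float = 1e-3):
    """Alternate lower-level gradient steps with hypergradient steps on m."""
    w, c, m = 0.0, 0.0, 0.0
    for _ in range(steps):
        # Approximate the hypergradient dL_val/dm through the one-step unroll (central difference).
        hypergrad = (unrolled_val_loss(m + eps, w, c, lr)
                     - unrolled_val_loss(m - eps, w, c, lr)) / (2.0 * eps)
        # Levels 1 and 2: ordinary gradient steps using the current mask parameter.
        w = w - lr * recon_grad_w(w, m)
        c = c - lr * train_grad_c(c, w)
        # Level 3: update the mask parameter with the hypergradient.
        m = m - lr_m * hypergrad
    return w, c, m


w, c, m = run_toy_mlo()
print(f"w={w:.4f}, c={c:.4f}, m={m:.4f}")
```

In this toy, the encoder tracks `w = m` and the classifier tracks `c = 2w`. The validation loss is therefore minimized when `m = 1.5`, and the loop drives `m` toward that value using only feedback from the top-level (validation) loss. This feedback path mirrors, in miniature, how downstream performance steers the masking parameters.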
Abstract
Masked Autoencoder (MAE) is a notable method for self-supervised pretraining in visual representation learning. It randomly masks image patches and reconstructs the masked patches from the unmasked ones. A key limitation of MAE is that it ignores the varying informativeness of different patches, because it selects the patches to mask uniformly at random. To overcome this, some approaches propose masking based on patch informativeness. However, these methods often do not consider the specific requirements of downstream tasks, which can lead to suboptimal representations for those tasks. In response, we introduce the Multi-level Optimized Mask Autoencoder (MLO-MAE), a novel framework that leverages end-to-end feedback from downstream tasks to learn an optimal masking strategy during pretraining. Our experiments show that MLO-MAE substantially advances visual representation learning: compared with existing methods, it achieves notable improvements across diverse datasets and tasks, demonstrating its adaptability and efficiency.
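To make the contrast between uniform masking and informativeness-driven masking concrete, here is a minimal NumPy sketch. `random_masking` follows the shuffle-by-random-noise idea commonly used in MAE implementations. `score_based_masking` is a hypothetical stand-in for a learned masking network that masks the patches with the highest scores. The function names, the highest-score-first selection rule, and the 196-patch / 75% mask-ratio setting are illustrative assumptions, not details taken from MLO-MAE.

```python
import numpy as np


def random_masking(num_patches: int, mask_ratio: float, rng: np.random.Generator):
    """Uniform MAE-style masking: shuffle patches by random noise and keep the first few."""
    len_keep = int(num_patches * (1 - mask_ratio))
    noise = rng.random(num_patches)             # one random value per patch
    ids_shuffle = np.argsort(noise)             # random permutation of patch indices
    ids_keep = np.sort(ids_shuffle[:len_keep])  # visible (unmasked) patches
    mask = np.ones(num_patches, dtype=bool)     # True = masked
    mask[ids_keep] = False
    return ids_keep, mask


def score_based_masking(scores: np.ndarray, mask_ratio: float):
    """Hypothetical learned masking: mask the patches with the highest scores."""
    num_patches = scores.shape[0]
    len_keep = int(num_patches * (1 - mask_ratio))
    num_mask = num_patches - len_keep
    ids_sorted = np.argsort(-scores)            # highest score first
    mask = np.zeros(num_patches, dtype=bool)    # True = masked
    mask[ids_sorted[:num_mask]] = True
    ids_keep = np.flatnonzero(~mask)            # visible (unmasked) patches
    return ids_keep, mask


rng = np.random.default_rng(0)
# 196 patches, e.g. a 224x224 image split into 16x16 patches, with 75% masked.
ids_keep, mask = random_masking(196, 0.75, rng)
print(len(ids_keep), int(mask.sum()))

# In a learned scheme the scores would come from a masking network; here they are given.
scores = np.arange(8, dtype=float)
print(score_based_masking(scores, 0.5))
```

Both functions return the indices of the visible patches and a boolean mask, so an encoder could consume either one. The difference lies only in how the masked set is chosen: by chance in uniform MAE masking, or by per-patch scores in an informativeness-driven scheme.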
