
PR-MIM: Delving Deeper into Partial Reconstruction in Masked Image Modeling

Zhong-Yu Li, Yunheng Li, Deng-Ping Fan, Ming-Ming Cheng

TL;DR

This work proposes a progressive reconstruction strategy and a furthest sampling strategy that reconstruct the thrown tokens in an extremely lightweight way, instead of abandoning them completely, and validates the effectiveness of the proposed method across various existing frameworks.

Abstract

Masked image modeling has achieved great success in learning representations but is limited by its huge computational cost. One cost-saving strategy makes the decoder reconstruct only a subset of the masked tokens and throw away the others; we refer to this method as partial reconstruction. However, it also degrades representation quality. Previous methods mitigate this issue by throwing away the tokens with the least information, identified either through temporal redundancy, which is unavailable for static images, or through attention maps, which incur extra cost and complexity. To address these limitations, we propose a progressive reconstruction strategy and a furthest sampling strategy that reconstruct the thrown tokens in an extremely lightweight way instead of abandoning them completely. This approach involves all masked tokens in supervision to ensure adequate pre-training while retaining the cost-reduction benefits of partial reconstruction. We validate the effectiveness of the proposed method across various existing frameworks. For example, when throwing 50% of the patches, we achieve lossless performance with ViT-B/16 while saving 28% of the FLOPs and 36% of the memory usage compared to standard MAE. Our source code will be made publicly available.
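The token bookkeeping behind plain partial reconstruction can be sketched in a few lines of Python. This is a hypothetical illustration, not the authors' code: it assumes MAE-style uniform random masking over the 196 patches of a 224×224 image split into 16×16 patches (as for ViT-B/16), a 75% mask ratio, and a thrown subset $x_t$ that is also drawn at random. It further assumes the decoder receives only the visible tokens plus the retained masked tokens $x_d$. The function name `partial_reconstruction_split` is made up for this example.

```python
import random


def partial_reconstruction_split(num_patches, mask_ratio, throw_ratio, seed=0):
    """Split patch indices into visible, decoded (x_d) and thrown (x_t) sets.

    Assumption: both the masked set and the thrown subset of it are chosen
    uniformly at random, as in plain (random) partial reconstruction.
    """
    rng = random.Random(seed)
    indices = list(range(num_patches))
    rng.shuffle(indices)

    num_masked = int(num_patches * mask_ratio)
    masked = indices[:num_masked]
    visible = sorted(indices[num_masked:])

    # Throw a fraction of the masked tokens; the rest are still reconstructed.
    num_thrown = int(num_masked * throw_ratio)
    thrown = sorted(masked[:num_thrown])
    decoded = sorted(masked[num_thrown:])
    return visible, decoded, thrown


visible, decoded, thrown = partial_reconstruction_split(196, 0.75, 0.5)

# Standard MAE decodes every position; partial reconstruction drops x_t.
mae_decoder_tokens = len(visible) + len(decoded) + len(thrown)
partial_decoder_tokens = len(visible) + len(decoded)
print(len(visible), len(decoded), len(thrown))   # 49 74 73
print(mae_decoder_tokens, partial_decoder_tokens)  # 196 123
```

Under these assumptions the decoder processes 123 tokens instead of 196. How that translates into FLOPs and memory depends on architectural details this sketch does not model, so it should not be read as reproducing the savings reported above.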

Paper Structure

This paper contains 14 sections, 1 equation, 7 figures, 14 tables, and 1 algorithm.

Figures (7)

  • Figure 1: Performance on ImageNet-1K. The bubble area is proportional to the training FLOPs. A higher throwing ratio means more tokens are thrown, which lowers the training cost but causes greater degradation in plain partial reconstruction.
  • Figure 2: Different strategies for masked image modeling and the corresponding FLOPs of the decoder. Partial reconstruction throws away some of the masked tokens $x_t$ and retains the others $x_d$ (marked by $\rm M$), while our progressive reconstruction scheme, proposed in the progressive reconstruction section, reconstructs the thrown tokens at minimal computational cost to ensure adequate training.
  • Figure 3: When partial reconstruction throws away a subset of the masked tokens, the proposed progressive reconstruction scheme reconstructs every masked token at minimal additional cost.
  • Figure 4: The $L_2$ norm of the gradient difference between different methods and standard MAE.
  • Figure 5: The furthest sampling strategy. The red box and arrows indicate that the thrown token at the center is reconstructed by aggregating information from the other tokens within the box.
  • ...and 2 more figures
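The box aggregation in Figure 5, together with one plausible reading of furthest sampling, can be sketched as follows. Both parts are assumptions rather than the paper's specification, and this is not the progressive reconstruction scheme itself. `furthest_sampling` applies greedy farthest-point sampling to 2D patch coordinates so that the thrown tokens are spread apart, leaving non-thrown neighbours around each one. `aggregate_thrown` then fills each thrown position with the mean feature of the non-thrown tokens in a (2·radius+1)×(2·radius+1) window. Both helper names, the use of a plain mean, and the choice of features (for example, decoder outputs) are hypothetical.

```python
import math


def furthest_sampling(candidates, grid_size, num_select):
    """Greedy farthest-point sampling over (row, col) patch coordinates.

    Hypothetical helper: starts from candidates[0] and repeatedly picks the
    candidate whose distance to the already-selected set is largest.
    """
    num_select = min(num_select, len(candidates))
    if num_select == 0:
        return []
    coords = {i: divmod(i, grid_size) for i in candidates}
    selected = [candidates[0]]
    chosen = {candidates[0]}
    # Distance from every candidate to its nearest selected token.
    min_dist = {i: math.dist(coords[i], coords[candidates[0]]) for i in candidates}
    while len(selected) < num_select:
        nxt = max((i for i in candidates if i not in chosen), key=lambda i: min_dist[i])
        selected.append(nxt)
        chosen.add(nxt)
        for i in candidates:
            min_dist[i] = min(min_dist[i], math.dist(coords[i], coords[nxt]))
    return sorted(selected)


def aggregate_thrown(features, thrown, grid_size, radius=1):
    """Fill each thrown token with the mean feature of the non-thrown tokens in its window."""
    thrown_set = set(thrown)
    filled = {}
    for t in thrown:
        row, col = divmod(t, grid_size)
        neighbours = [
            features[r * grid_size + c]
            for r in range(max(0, row - radius), min(grid_size, row + radius + 1))
            for c in range(max(0, col - radius), min(grid_size, col + radius + 1))
            if r * grid_size + c not in thrown_set
        ]
        if neighbours:  # a window made only of thrown tokens is left unfilled
            dim = len(neighbours[0])
            filled[t] = [sum(v[d] for v in neighbours) / len(neighbours) for d in range(dim)]
    return filled


# Toy 4x4 grid; for brevity every token is a candidate (in practice: masked tokens only).
grid = 4
thrown = furthest_sampling(list(range(grid * grid)), grid, 4)
features = [[float(i)] for i in range(grid * grid)]  # feature of token i is [i]
print(thrown)  # [0, 3, 12, 15]
print(aggregate_thrown(features, thrown, grid)[3])  # [5.0]
```

On this toy grid, farthest-point sampling picks the four corners, so every thrown token still has three non-thrown neighbours inside its 3×3 window. Random throwing gives no such guarantee.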
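The quantity plotted in Figure 4 can be computed as in the sketch below. It assumes that each method's gradients and the standard MAE gradients are taken for the same parameters on the same batch, then flattened into a single vector. Here gradients are stored in a hypothetical dict that maps parameter names to lists of floats. In a real framework they would come from the autograd gradient tensors.

```python
import math


def grad_diff_l2(grads_method, grads_mae):
    """L2 norm of (g_method - g_MAE) over all parameters, flattened into one vector.

    grads_method, grads_mae: dict mapping parameter name -> list of float
    gradients (a hypothetical format chosen for this sketch).
    """
    if grads_method.keys() != grads_mae.keys():
        raise ValueError("both gradient dicts must cover the same parameters")
    total = 0.0
    for name in grads_method:
        g_a, g_b = grads_method[name], grads_mae[name]
        if len(g_a) != len(g_b):
            raise ValueError(f"shape mismatch for parameter {name!r}")
        total += sum((a - b) ** 2 for a, b in zip(g_a, g_b))
    return math.sqrt(total)


# Toy example: element-wise differences are (3, 0) and (4,), so the norm is sqrt(9 + 16) = 5.
g_partial = {"w": [3.0, 1.0], "b": [4.0]}
g_mae = {"w": [0.0, 1.0], "b": [0.0]}
print(grad_diff_l2(g_partial, g_mae))  # 5.0
```

A smaller value means the method's update direction stays closer to that of standard MAE.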