Keypoint Aware Masked Image Modelling

Madhava Krishna, A V Subramanyam

TL;DR

This work addresses SimMIM's weakness in linear probing by introducing KAMIM, a patch-wise weighting scheme derived from keypoint density computed with FAST (with SIFT and ORB as alternatives). By weighting the reconstruction loss by per-patch keypoint density, KAMIM delivers substantial gains in linear probing on ImageNet-1K with ViT-B (from 16.12% to 33.97% top-1) and modest improvements in fine-tuning, while preserving training efficiency. Comprehensive experiments across datasets and architectures show that KAMIM benefits larger pretraining datasets and yields representations with contrastive-like properties, including longer attention distances and global self-attention. The study also analyzes the learned representations through token-level visualization and Fourier analysis, drawing parallels to contrastive learning and revealing attention-collapse-like behavior. The authors provide public code for replication.
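The weighting idea can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names `keypoint_density`, `patch_weights`, and `weighted_l1_loss` are hypothetical; keypoints are assumed to arrive as (x, y) pixel coordinates (for example the `.pt` values of keypoints returned by OpenCV's `cv2.FastFeatureDetector_create().detect(...)`); the temperature mapping `exp(d / T)` rescaled to mean 1 is an assumption, since its exact form is not given here; and normalizing the loss by the number of masked pixels is likewise assumed. The block mean over non-overlapping patches is equivalent to convolving the binary keypoint map with an averaging kernel whose size and stride both equal the patch size.

```python
import numpy as np


def keypoint_density(keypoints, height, width, patch_size):
    """Fraction of pixels in each patch that hold a keypoint.

    keypoints: iterable of (x, y) pixel coordinates (e.g. FAST keypoint .pt values).
    Returns an array of shape (height // patch_size, width // patch_size).
    """
    kp_map = np.zeros((height, width), dtype=np.float64)
    for x, y in keypoints:
        kp_map[int(y), int(x)] = 1.0  # binary keypoint map
    gh, gw = height // patch_size, width // patch_size
    # Block mean over non-overlapping patches == averaging kernel with stride patch_size.
    blocks = kp_map[: gh * patch_size, : gw * patch_size].reshape(gh, patch_size, gw, patch_size)
    return blocks.mean(axis=(1, 3))


def patch_weights(density, temperature):
    """ASSUMED temperature mapping: exp(density / T), rescaled so the mean weight is 1.

    Smaller temperatures make keypoint-dense patches dominate more strongly.
    """
    w = np.exp(density / temperature)
    return w / w.mean()


def weighted_l1_loss(pred, target, mask, weights, patch_size):
    """l1 reconstruction loss on masked patches, scaled per patch by `weights`.

    pred, target: (H, W) pixel arrays with H, W divisible by patch_size.
    mask, weights: (H // p, W // p) patch grids, mask = 1 for masked patches.
    Normalising by the masked-pixel count is an assumption.
    """
    block = np.ones((patch_size, patch_size))
    w_pix = np.kron(weights, block)  # upsample patch weights to pixel resolution
    m_pix = np.kron(mask, block)     # upsample patch mask to pixel resolution
    err = np.abs(pred - target) * w_pix * m_pix
    return err.sum() / (m_pix.sum() + 1e-8)
```

For example, with keypoints (0, 0), (1, 1), and (5, 5) on an 8 x 8 image and `patch_size=4`, the top-left patch has density 2/16 = 0.125, the bottom-right patch 1/16, and the other two patches 0, so the top-left patch receives the largest weight.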

Abstract

SimMIM is a widely used method for pretraining vision transformers with masked image modeling. However, despite its strong fine-tuning performance, it has been shown to perform sub-optimally in linear probing. We propose an efficient patch-wise weighting derived from keypoint features that captures local information and provides better context during SimMIM's reconstruction phase. Our method, KAMIM, improves top-1 linear probing accuracy from 16.12% to 33.97% and fine-tuning accuracy from 76.78% to 77.3% on ImageNet-1K with a ViT-B trained for the same number of epochs. We conduct extensive tests across datasets, keypoint extractors, and model architectures, and observe that patch-wise weighting improves linear probing performance for larger pretraining datasets. We also analyze the learned representations of a ViT-B trained with KAMIM and observe that they behave similarly to those learned by contrastive methods, with longer attention distances and homogeneous self-attention across layers. Our code is publicly available at https://github.com/madhava20217/KAMIM.
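Attention distance is commonly measured as the mean attention distance: for each query patch, the attention-weighted average of the pixel distances to all key patches, averaged over queries. The NumPy sketch below assumes that definition (the abstract does not spell out the exact metric), a square grid of patch tokens with the cls token already removed, and a hypothetical function name `mean_attention_distance`.

```python
import numpy as np


def mean_attention_distance(attn, grid_size, patch_size):
    """Per-head mean attention distance in pixels.

    attn: (heads, N, N) attention weights over N = grid_size**2 patch tokens,
    with the cls token already removed.
    """
    # Rows may no longer sum to 1 after dropping the cls token, so renormalise.
    attn = attn / attn.sum(axis=-1, keepdims=True)
    rows, cols = np.meshgrid(np.arange(grid_size), np.arange(grid_size), indexing="ij")
    coords = np.stack([rows.ravel(), cols.ravel()], axis=-1) * patch_size  # (N, 2) pixels
    # Pairwise Euclidean distance between patch positions, in pixels.
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (N, N)
    per_query = (attn * dist[None]).sum(axis=-1)  # expected distance for each query
    return per_query.mean(axis=-1)                # average over queries -> (heads,)
```

A head that attends only to its own patch scores 0, while uniform attention over a 2 x 2 grid of 16-pixel patches scores 8 + 4*sqrt(2), roughly 13.66 pixels; larger values indicate more global attention.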

Paper Structure (32 sections, 3 equations, 10 figures, 5 tables)

Figures (10)

  • Figure 1: A diagram of how KAMIM works. We detect FAST keypoints in an image $\mathbf{I}$ and compute the keypoint density over patches of a pre-defined size ($w_{ps}$), which serves as the weight $\mathbf{W}$ during the reconstruction phase. An averaging convolutional kernel is used to compute the keypoint density efficiently and obtain the weight matrix $\mathbf{W}$. A temperature parameter controls the weighting factor. Unlike SimMIM, the $\ell_1$ loss on the pixel values predicted by a lightweight prediction head is weighted by $\mathbf{W}$.
  • Figure 2: Reconstructed images from SimMIM and KAMIM. The first row shows the original images and the second row the masked images. The reconstructions are unnormalized to obtain these visualizations. Both methods achieve similar visual fidelity, and it is hard to judge which is better: SimMIM sometimes performs better, as in image 5, while KAMIM captures details better, as in image 1.
  • Figure 3: t-SNE visualization of token-level representations of images from different classes, with each class shown in a different colour. Embeddings from the last layer are used. The cls token is dropped, leaving 144 tokens per image for SimMIM and KAMIM and 196 tokens for MoCo.
  • Figure 4: The plot resulting from the Fourier analysis of the output of intermediate layers in. The hidden states first undergo a 2D Fourier transform, followed by a log-amplitude operation and subtraction of the first latent to obtain a relative log-amplitude plot.
  • Figure 5: Attention maps for the query token (marked with a red box) for SimMIM, KAMIM, and MoCo, taken from the last layer of a ViT-B. The top row shows the base image, and the subsequent rows show the attention maps for SimMIM, KAMIM, and MoCo. The attention maps for SimMIM and KAMIM have $12 \times 12$ tokens, while those for MoCo have $14 \times 14$ tokens.
  • ...and 5 more figures
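The Fourier analysis described for Figure 4 can be sketched in NumPy. This follows one common recipe and fills in details the caption does not state, so treat them as assumptions: patch tokens (cls dropped) are reshaped to a square grid, the log amplitude is centred with `fftshift` and averaged over images and channels, only the half-diagonal from the zero-frequency term outward is kept, and "differencing with the first latent" is interpreted as subtracting the first entry of that half-diagonal. The function name `relative_log_amplitude` is hypothetical.

```python
import numpy as np


def relative_log_amplitude(hidden, grid_size):
    """Relative log amplitude along the half-diagonal of the centred spectrum.

    hidden: (B, N, C) patch tokens with the cls token dropped, N = grid_size**2.
    Returns an array of length grid_size - grid_size // 2, starting at zero frequency.
    """
    b, n, c = hidden.shape
    x = hidden.reshape(b, grid_size, grid_size, c)
    # 2D FFT over the two spatial axes, then log amplitude (epsilon avoids log(0)).
    f = np.log(np.abs(np.fft.fft2(x, axes=(1, 2))) + 1e-6)
    # Centre the zero-frequency term, then average over images and channels.
    f = np.fft.fftshift(f, axes=(1, 2)).mean(axis=(0, 3))
    diag = np.diag(f)[grid_size // 2:]  # zero frequency outward to the highest frequency
    return diag - diag[0]               # relative to the first entry
```

One way to compare models is to apply this to each layer's hidden states and track how the highest-frequency entry changes with depth.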