Stochastic positional embeddings improve masked image modeling

Amir Bar, Florian Bordes, Assaf Shocher, Mahmoud Assran, Pascal Vincent, Nicolas Ballas, Trevor Darrell, Amir Globerson, Yann LeCun

TL;DR

The paper addresses location uncertainty in Masked Image Modeling by introducing Stochastic Positional Embeddings (StoP), which inject Gaussian noise into masked token positions to prevent overfitting to exact locations. StoP defines $\hat{\psi}_j \sim \mathcal{N}(\psi_j, \Sigma)$ with a learned covariance $\Sigma = \sigma A A^{T}$ and uses a reparameterization $\hat{\psi}_j = A n_j + \psi_j$, coupled with weight tying ($A$ to $B$) to avoid collapse. When applied to I-JEPA, StoP yields consistent improvements on downstream tasks, including $+1.7\%$ on ImageNet linear probing with ViT-B and $+2.5\%$ with ViT-H at $1\%$ labels, and ablations show the importance of applying noise to masked tokens and learning $\Sigma$ rather than fixing it. The approach is lightweight (three extra lines of code) and enhances robustness by promoting spatial smoothing and semantic feature learning, offering practical gains across recognition and dense-prediction tasks.
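The stated covariance follows directly from the reparameterization. A short check, assuming the noise is drawn as $n_j \sim \mathcal{N}(0, \sigma I)$ (an assumption implied by $\Sigma = \sigma A A^{T}$ but not spelled out in this summary):

$$
\mathbb{E}[\hat{\psi}_j] = A\,\mathbb{E}[n_j] + \psi_j = \psi_j,
\qquad
\mathrm{Cov}(\hat{\psi}_j) = A\,\mathrm{Cov}(n_j)\,A^{T} = A(\sigma I)A^{T} = \sigma A A^{T} = \Sigma .
$$

Since an affine map of a Gaussian vector is again Gaussian, this gives $\hat{\psi}_j \sim \mathcal{N}(\psi_j, \Sigma)$. The sample is also a differentiable function of $A$, which is what allows $\Sigma$ to be learned by gradient descent.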

Abstract

Masked Image Modeling (MIM) is a promising self-supervised learning approach that enables learning from unlabeled images. Despite its recent success, learning good representations through MIM remains challenging because it requires predicting the right semantic content in accurate locations. For example, given an incomplete picture of a dog, we can guess that there is a tail, but we cannot determine its exact location. In this work, we propose to incorporate location uncertainty into MIM by using stochastic positional embeddings (StoP). Specifically, we condition the model on stochastic masked token positions drawn from a Gaussian distribution. StoP reduces overfitting to location features and guides the model toward learning features that are more robust to location uncertainties. Quantitatively, StoP improves downstream MIM performance on a variety of downstream tasks, including $+1.7\%$ on ImageNet linear probing using ViT-B, and $+2.5\%$ for ViT-H using $1\%$ of the data.
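To make the sampling step concrete, here is a minimal NumPy sketch of drawing stochastic positions via $\hat{\psi}_j = A n_j + \psi_j$. It is an illustration under stated assumptions, not the authors' implementation: the function name `stochastic_positions`, the noise dimension `k`, and the choice $n_j \sim \mathcal{N}(0, \sigma I)$ are hypothetical, and `A` is a fixed random array here rather than a learned parameter.

```python
import numpy as np


def stochastic_positions(psi, A, sigma, rng):
    """Sample psi_hat_j = A n_j + psi_j with n_j ~ N(0, sigma * I).

    psi:   (num_masked, d) deterministic positional embeddings of masked tokens
    A:     (d, k) matrix (learned in the paper; a fixed array in this sketch)
    sigma: scalar noise variance (assumed hyperparameter)
    rng:   numpy.random.Generator used for sampling
    """
    num_masked = psi.shape[0]
    k = A.shape[1]
    # n_j ~ N(0, sigma * I): scale standard normal samples by sqrt(sigma).
    noise = np.sqrt(sigma) * rng.standard_normal((num_masked, k))
    # Row-wise A n_j is noise @ A.T; adding psi shifts each sample to its mean.
    return noise @ A.T + psi


# Usage: compare the empirical covariance of the samples with sigma * A A^T.
rng = np.random.default_rng(0)
d, k, sigma = 4, 4, 0.25
A = rng.standard_normal((d, k))
psi = np.zeros((200_000, d))  # same mean for every sample, for the check only
samples = stochastic_positions(psi, A, sigma, rng)
empirical_cov = np.cov(samples, rowvar=False)
print(np.abs(empirical_cov - sigma * A @ A.T).max())  # small sampling error
```

With `sigma = 0` the function returns the deterministic embeddings unchanged, which corresponds to the standard, non-stochastic setup.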

Paper Structure

This paper contains 15 sections, 4 theorems, 21 equations, 6 figures, 16 tables, and 1 algorithm.

Key Result

Proposition 3.1

If the weights of $A$ and $B$ are tied (namely, $A=B$), then $\left. \frac{dJ_{tied}}{dA} \right|_{A=0} = 0$ iff $\left. \frac{dJ_{det}}{dB} \right|_{B=0} = 0$.

Figures (6)

  • Figure 1: Given a partial image of a dog, can you precisely determine the location of its tail? Existing Masked Image Modeling (MIM) models like MAE (He et al., 2021) and I-JEPA (Assran et al., 2023) predict tokens deterministically and do not model location uncertainties (a). We propose to predict the target (masked tokens) in stochastic positions (StoP), which prevents overfitting to location features. StoP leads to improved MIM performance on downstream tasks, including linear probing on ImageNet (b).
  • Figure 2: Masked image modeling using stochastic positional embeddings (StoP). $g_{\phi}$ predicts target tokens given masked tokens with stochastic positions $m_j$ and context tokens $c_i$ obtained via $f_{\theta}$. StoP is applied to masked tokens only, leading to features that are more robust to location uncertainties.
  • Figure 3: Learned vs. predefined stochastic positions. Using the learned covariance matrix as in StoP, i.e., $\Sigma=\sigma AA^T$, leads to a $+3.5\%$ improvement, compared to smaller gains with a fixed covariance matrix $\Sigma=\sigma I$. Accuracy is reported based on linear probing evaluation using $1\%$ of the data from IN-1k.
  • Figure 4: Increasing $\sigma$ induces regularization. Changing the prior $\sigma$ (where $\Sigma=\sigma AA^T$) induces regularization over $A$ and increases the norm of the masked token, which preserves the masked token information relative to the added noise.
  • Figure 5: Similarity matrices of deterministic and stochastic positional embedding (StoP) to a query position. Each row represents the similarity given a different query position. StoP leads to a spatially smooth similarity matrix, thereby making it hard to distinguish the exact location of a given patch.
  • ...and 1 more figure

Theorems & Definitions (6)

  • Proposition 3.1
  • Proposition 3.2
  • Proposition 1.1
  • proof
  • Proposition 1.2
  • proof