
Two-stage Vision Transformers and Hard Masking offer Robust Object Representations

Ananthu Aniraj, Cassio F. Dantas, Dino Ienco, Diego Marcos

Abstract

Context can strongly affect object representations, sometimes leading to undesired biases, particularly when objects appear in out-of-distribution backgrounds at inference. At the same time, many object-centric tasks require leveraging context to identify the relevant image regions. We posit that this conundrum, in which context is simultaneously needed and a potential nuisance, can be addressed by an attention-based approach that uses learned binary attention masks to ensure that only attended image regions influence the prediction. To test this hypothesis, we evaluate a two-stage framework: stage 1 processes the full image to discover object parts and identify task-relevant regions, for which context cues are likely to be needed, while stage 2 leverages input attention masking to restrict its receptive field to these regions, enabling a focused analysis while filtering out potentially spurious information. Both stages are trained jointly, allowing stage 2 to refine stage 1. The explicit nature of the semantic masks also makes the model's reasoning auditable, enabling powerful test-time interventions to further enhance robustness. Extensive experiments across diverse benchmarks demonstrate that this approach significantly improves robustness against spurious correlations and out-of-distribution backgrounds. Code: https://github.com/ananthu-aniraj/ifam
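The abstract's central claim is that restricting stage 2's attention to the selected tokens guarantees that unselected regions cannot influence the prediction. The following is a minimal, stdlib-only sketch of that property, not the authors' implementation from the linked repository. It assumes a toy transformer: a single head, identity Q/K/V projections, a residual connection, no MLP or LayerNorm, and mean pooling. The names `masked_self_attention` and `pooled_prediction` are hypothetical. Setting `mask_attention=False` mimics the alternative shown at the top of Figure 1: every layer attends to all tokens, and the mask is applied only to the final features.

```python
import math
from typing import List

Vector = List[float]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def masked_self_attention(tokens: List[Vector], keep: List[bool]) -> List[Vector]:
    """Single-head self-attention with identity Q/K/V projections (simplifying assumption).

    Tokens with keep=False are dropped from the key/value set. This is equivalent to
    adding -inf to their attention logits, so they cannot influence any output."""
    if not any(keep):
        raise ValueError("at least one token must be kept")
    scale = 1.0 / math.sqrt(len(tokens[0]))
    kept = [t for t, m in zip(tokens, keep) if m]
    out = []
    for q in tokens:
        logits = [_dot(q, k) * scale for k in kept]
        peak = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in logits]
        total = sum(weights)
        attended = [sum(w * v[j] for w, v in zip(weights, kept)) / total
                    for j in range(len(q))]
        out.append([a + b for a, b in zip(q, attended)])  # residual connection
    return out


def pooled_prediction(tokens: List[Vector], keep: List[bool], depth: int = 2,
                      mask_attention: bool = True) -> Vector:
    """Stack `depth` layers, then mean-pool over the kept tokens only.

    If mask_attention is False, every layer attends to all tokens and `keep` is
    applied only at pooling time, mimicking masking of a deep feature map."""
    attn_keep = keep if mask_attention else [True] * len(tokens)
    x = tokens
    for _ in range(depth):
        x = masked_self_attention(x, attn_keep)
    kept = [t for t, m in zip(x, keep) if m]
    return [sum(t[j] for t in kept) / len(kept) for j in range(len(kept[0]))]
```

In this toy model, a token that is excluded from every key/value set can be replaced with arbitrary content and the pooled output of the masked model stays bit-for-bit identical. With late (feature-level) masking, the same swap still changes the output, because the kept tokens already mixed in information from the excluded one.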

Paper Structure

This paper contains 21 sections, 3 equations, 10 figures, 7 tables.

Figures (10)

  • Figure 1: Previous attention-based approaches apply the attention mask to a deep feature tensor, where all locations can be affected by the whole image due to large receptive fields (top). Our approach ensures that only the selected tokens contribute to the downstream task (bottom).
  • Figure 2: Left: iFAM first discovers task-relevant regions (Stage 1) and then classifies using only the selected regions (Stage 2), preventing reliance on background cues. Right: At test time, we leverage the model’s inherently faithful region attribution to design (training-free) intervention strategies that further enhance robustness to spurious correlations.
  • Figure 3: Leave-one-out (LOO) part removal intervention results on MetaShift (left) and SIIM-ACR (right) for $K=8$. The bottom right image shows a heatmap of the average pneumothorax occurrence across the dataset.
  • Figure 4: Qualitative results of part discovery of our model on the CUB dataset ($K=8$), along with results on the corresponding out-of-distribution (OOD) images from the WB200 (WaterBirds200) dataset and the effect of the test-time intervention of thresholding on the OOD images.
  • Figure 5: Qualitative results for part discovery for the iFAM model (without any *X) trained on the CUB dataset for different values of K, the number of foreground parts.
  • ...and 5 more figures
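Figure 3 refers to a leave-one-out (LOO) part removal intervention. The sketch below shows one plausible reading of it under our own assumptions; the paper's exact procedure may differ. Each token carries a part label, where labels 0..K-1 are foreground parts and K is background. A caller-supplied, hypothetical `score_fn` maps a stage-2 keep mask to a class score. Each foreground part is removed from the mask in turn, and the resulting score drop is recorded. The function name `leave_one_out_scores` is also hypothetical.

```python
from typing import Callable, Dict, List, Sequence


def leave_one_out_scores(part_ids: Sequence[int], num_parts: int,
                         score_fn: Callable[[List[bool]], float]) -> Dict[int, float]:
    """Return, for each foreground part k, the score drop when part k is masked out.

    Assumption: labels 0..num_parts-1 are foreground parts, and label num_parts is
    background, which is excluded from the keep mask from the start."""
    base_keep = [p < num_parts for p in part_ids]
    base = score_fn(base_keep)
    deltas: Dict[int, float] = {}
    for k in range(num_parts):
        # Drop only the tokens assigned to part k; everything else stays as in base_keep.
        keep = [m and p != k for m, p in zip(base_keep, part_ids)]
        deltas[k] = base - score_fn(keep) if any(keep) else float("nan")
    return deltas
```

With a toy score that simply sums per-token evidence, a part with a large delta is one the prediction relies on. Tokens labelled as background never contribute, because they are excluded from the base mask.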