Unsupervised Salient Patch Selection for Data-Efficient Reinforcement Learning

Zhaohui Jiang, Paul Weng

TL;DR

This paper addresses the data inefficiency of vision-based reinforcement learning by introducing SPIRL, a method that pretrains a compact Masked Autoencoder (MAE) to extract salient patches from input frames, where a patch is salient if it is hard to reconstruct from its neighbors. The approach adaptively selects a variable number of patches per frame using a reconstruction-error map and a Lorenz-curve-based criterion, then processes the patches with a Transformer-based RL module that can operate without convolutional features. Key contributions include fast MAE pretraining on modest data, a dynamic patch-count mechanism that avoids a fixed K, a Transformer-based policy that aggregates variable-size patch sets, and interpretability via analyses of the policy's attention. Experiments on Atari show improved data efficiency in low-data regimes; ablations support the value of salient patches and dynamic patch selection, and qualitative analyses give insight into the learned attention. Its convolution-free design and interpretability make SPIRL a promising approach for patch-focused representations in data-limited reinforcement learning.
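The adaptive patch-count step can be illustrated with a minimal sketch. It assumes one plausible reading of a Lorenz-curve criterion: sort the per-patch reconstruction errors in decreasing order, and keep the top K patches where the cumulative share of error is farthest above the diagonal of equal shares. The summary does not spell out the exact rule, so the function name `select_salient_patches` and the gap-to-diagonal rule are illustrative assumptions, not the authors' implementation.

```python
def select_salient_patches(errors):
    """Return indices of patches to keep, most salient first.

    `errors` is a flat list of non-negative per-patch reconstruction errors.
    Assumed rule (hypothetical): choose K maximizing
        (cumulative error share of top-K) - (K / number of patches).
    """
    n = len(errors)
    total = sum(errors)
    if n == 0 or total <= 0:
        return []  # nothing stands out, so keep no patch

    # Patch indices ordered from largest to smallest error (stable for ties).
    order = sorted(range(n), key=lambda i: -errors[i])

    best_k, best_gap = 0, 0.0
    cumulative = 0.0
    for k, idx in enumerate(order, start=1):
        cumulative += errors[idx]
        gap = cumulative / total - k / n  # distance above the diagonal
        if gap > best_gap:
            best_k, best_gap = k, gap
    return order[:best_k]


# One patch dominates the error map, so only that patch is kept.
single = select_salient_patches([0.0, 0.9, 0.05, 0.05, 0.0, 0.0, 0.0, 0.0])

# Two equally salient patches are both kept.
pair = select_salient_patches([0.5, 0.5, 0.0, 0.0])
```

Under this rule K varies from frame to frame: a frame with one bright error peak yields one patch, while a frame with several comparable peaks yields several, which matches the motivation for avoiding a fixed K.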

Abstract

To improve the sample efficiency of vision-based deep reinforcement learning (RL), we propose a novel method, called SPIRL, to automatically extract important patches from input images. Following Masked Auto-Encoders, SPIRL is based on Vision Transformer models pre-trained in a self-supervised fashion to reconstruct images from randomly-sampled patches. These pre-trained models can then be exploited to detect and select salient patches, defined as hard to reconstruct from neighboring patches. In RL, the SPIRL agent processes selected salient patches via an attention module. We empirically validate SPIRL on Atari games to test its data-efficiency against relevant state-of-the-art methods, including some traditional model-based methods and keypoint-based models. In addition, we analyze our model's interpretability capabilities.

Paper Structure

This paper contains 36 sections, 7 equations, 8 figures, 15 tables.

Figures (8)

  • Figure 1: Adaptation of MAE illustrated on Seaquest: a higher-capacity decoder can extract the background information, while a smaller encoder focuses on embedding only salient elements in images. Overall, the adapted model is more than 50$\times$ smaller than the MAE of He et al. (2022).
  • Figure 2: Visualization from the pre-trained MAE model. The first 3 columns are visualizations related to a certain frame: (1$^{st}$ column) Original frames; (2$^{nd}$) Reconstruction from patch surroundings; (3$^{rd}$) Reconstruction error maps where brighter colors indicate larger values. The last 3 columns are obtained without frame patches, but with different inputs to the pre-trained MAE decoder: Reconstructions with (4$^{th}$) $\{PE_{i,j}\}$ added to $[mask]$ token; (5$^{th}$) only $\{PE_{i,j}\}$; (6$^{th}$) only $[mask]$ token.
  • Figure 3: Different salient patch selection strategies in three different frames (one per column) of Seaquest: (1$^{st}$ row) blue lines (resp. red dashed lines) are cumulative sum of errors in ${\bm E}$ (resp. $p^*$); (2$^{nd}$) selected patches (red squares) with $K$ determined by $p^*$ from 1$^{st}$ row; (3$^{rd}$ & 4$^{th}$) selected patches with pre-defined $K$.
  • Figure 4: Comparison of key-point (represented as white crosses) / salient-patch (transparent red patches) selection using Transporter, PermaKey, or SPIRL. Each row shows the same frame of a given game. Visualizations for other frames can be found in Appendix A.4.
  • Figure 5: Overview of SPIRL: (1) the pre-trained MAE rebuilds each patch from its surroundings, (2) a reconstruction error map is computed as the MSE between pairs of reconstructed and target patches, (3) the most salient patches are selected according to accumulated errors, (4) their embeddings are obtained from the MAE encoder and concatenated to previous embeddings, and (5) a Transformer-based network aggregates them, from which an MLP computes the estimated $Q$-values.
  • ...and 3 more figures
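
Step (2) of the Figure 5 overview, the reconstruction error map, can be sketched as a per-patch mean squared error. This is a minimal illustration under stated assumptions: frames are single-channel 2D lists of floats, the reconstruction is already given (the MAE itself is not modeled here), and the name `patch_error_map` and its arguments are hypothetical.

```python
def patch_error_map(target, recon, patch):
    """Mean squared error per non-overlapping patch.

    `target` and `recon` are equally sized 2D lists (height x width);
    `patch` is the assumed square patch side length in pixels.
    Returns a (height // patch) x (width // patch) grid of errors.
    """
    height, width = len(target), len(target[0])
    if height % patch or width % patch:
        raise ValueError("frame size must be a multiple of the patch size")

    rows, cols = height // patch, width // patch
    sums = [[0.0] * cols for _ in range(rows)]

    # Accumulate squared pixel differences into the patch that owns each pixel.
    for y in range(height):
        for x in range(width):
            diff = recon[y][x] - target[y][x]
            sums[y // patch][x // patch] += diff * diff

    area = patch * patch
    return [[s / area for s in row] for row in sums]


# A 4x4 frame split into 2x2 patches; only one top-left pixel is misreconstructed.
target = [[0.0] * 4 for _ in range(4)]
recon = [row[:] for row in target]
recon[0][0] = 2.0
error_map = patch_error_map(target, recon, 2)  # [[1.0, 0.0], [0.0, 0.0]]
```

Flattening the resulting grid gives the per-patch errors from which step (3) picks the most salient patches.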