Unsupervised Salient Patch Selection for Data-Efficient Reinforcement Learning
Zhaohui Jiang, Paul Weng
TL;DR
This paper addresses the data inefficiency of vision-based reinforcement learning (RL) by introducing SPIRL, a method that pretrains a compact Masked Autoencoder (MAE) to extract salient patches from input frames, where a patch counts as salient if it is hard to reconstruct from its neighbors. The approach adaptively selects a variable number of patches per frame using a reconstruction-error map and a Lorenz-curve-based criterion. It then processes the selected patches with a Transformer-based RL module that can operate without convolutional features. Key contributions include fast MAE pretraining on modest amounts of data, a dynamic patch-count mechanism that avoids fixing the number of patches K in advance, a Transformer-based policy that aggregates variable-size patch sets, and interpretability through analyses of the policy's attention. Experiments on Atari show improved data efficiency in low-data regimes. Ablations confirm the value of salient patches and of dynamic patch selection, and qualitative analyses give insight into the learned attention. SPIRL's data-efficient, convolution-free design and its interpretability make it a promising approach to scalable, patch-based representations for data-limited RL tasks.
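The saliency-then-select pipeline described above can be illustrated with a small NumPy sketch. It rests on two loudly stated assumptions. First, the pretrained MAE is replaced by a much simpler stand-in: each patch is "reconstructed" as the mean of its 4-connected neighbor patches, and the per-patch mean squared error forms the reconstruction-error map. Second, the Lorenz-curve criterion is read here as "keep the fewest highest-error patches whose errors together account for at least a fraction `alpha` of the total error", which is one plausible reading and not necessarily the paper's exact rule. The names `patchify`, `neighbor_error_map`, `lorenz_select` and the value `alpha=0.8` are hypothetical.

```python
import numpy as np


def patchify(frame, p):
    """Split an (H, W) grayscale frame into a (H//p, W//p, p*p) grid of flattened patches."""
    h, w = frame.shape
    gh, gw = h // p, w // p
    grid = frame[: gh * p, : gw * p].reshape(gh, p, gw, p).transpose(0, 2, 1, 3)
    return grid.reshape(gh, gw, p * p)


def neighbor_error_map(patches):
    """Hypothetical stand-in for the pretrained MAE: predict each patch as the mean
    of its 4-connected neighbour patches and return the per-patch MSE.
    Assumes the patch grid is at least 2x2 so every patch has a neighbour."""
    gh, gw, _ = patches.shape
    err = np.zeros((gh, gw))
    for i in range(gh):
        for j in range(gw):
            nbrs = [patches[a, b]
                    for a, b in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1))
                    if 0 <= a < gh and 0 <= b < gw]
            pred = np.mean(nbrs, axis=0)
            err[i, j] = np.mean((patches[i, j] - pred) ** 2)
    return err


def lorenz_select(errors, alpha=0.8):
    """Assumed Lorenz-style rule: return flat indices of the fewest highest-error
    patches whose summed error reaches a fraction `alpha` of the total error."""
    flat = errors.ravel()
    order = np.argsort(flat)[::-1]           # patches sorted by decreasing error
    total = flat.sum()
    if total == 0:                           # uniform frame: keep a single patch
        return order[:1]
    share = np.cumsum(flat[order]) / total   # Lorenz curve read from the high-error end
    k = min(int(np.searchsorted(share, alpha)) + 1, flat.size)
    return order[:k]                         # k varies from frame to frame


# Toy frame: a single bright 2x2 object on a black 8x8 background.
frame = np.zeros((8, 8))
frame[2:4, 4:6] = 1.0
errors = neighbor_error_map(patchify(frame, 2))   # 4x4 grid of patch errors
selected = lorenz_select(errors, alpha=0.8)
print(errors.shape, selected)
```

With the toy frame, the patch containing the object is the hardest to predict from its neighbors and is selected first. A lower `alpha` keeps fewer patches, which is how this reading of the criterion yields a variable patch count per frame without fixing K.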
Abstract
To improve the sample efficiency of vision-based deep reinforcement learning (RL), we propose a novel method, called SPIRL, to automatically extract important patches from input images. Following Masked Autoencoders (MAE), SPIRL is based on Vision Transformer models pre-trained in a self-supervised fashion to reconstruct images from randomly sampled patches. These pre-trained models can then be exploited to detect and select salient patches, defined as patches that are hard to reconstruct from neighboring patches. In RL, the SPIRL agent processes the selected salient patches via an attention module. We empirically validate SPIRL on Atari games, comparing its data efficiency against relevant state-of-the-art methods, including some traditional model-based methods and keypoint-based models. In addition, we analyze the interpretability of our model.
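Because each frame can contribute a different number of salient patches, the attention module has to handle variable-size inputs. The NumPy sketch below shows one common way to do this: pad the patch sets to a common length, then use a boolean mask so that padded slots receive zero attention weight. It is a simplified single-query attention pooling, not the paper's Transformer architecture. The random patch embeddings, the query vector, the projections `w_k` and `w_v`, and the helper names `pad_patch_sets` and `masked_attention_pool` are all hypothetical stand-ins for learned components.

```python
import numpy as np


def pad_patch_sets(patch_sets, dim):
    """Pad a list of (k_i, dim) patch-embedding arrays to a (B, N_max, dim) batch
    plus a (B, N_max) boolean mask marking the real (non-padding) slots."""
    n_max = max(len(s) for s in patch_sets)
    tokens = np.zeros((len(patch_sets), n_max, dim))
    mask = np.zeros((len(patch_sets), n_max), dtype=bool)
    for b, s in enumerate(patch_sets):
        s = np.asarray(s)
        tokens[b, : len(s)] = s
        mask[b, : len(s)] = True
    return tokens, mask


def masked_attention_pool(tokens, mask, query, w_k, w_v):
    """Single-query scaled dot-product attention over each frame's real patches.
    Assumes every frame has at least one real patch (otherwise the softmax is NaN)."""
    keys = tokens @ w_k                                # (B, N, D_k)
    values = tokens @ w_v                              # (B, N, D_v)
    scores = keys @ query / np.sqrt(query.shape[0])    # (B, N)
    scores = np.where(mask, scores, -np.inf)           # padding gets zero weight
    scores -= scores.max(axis=1, keepdims=True)        # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)      # softmax over real patches
    pooled = np.einsum("bn,bnd->bd", weights, values)  # weighted sum of values
    return pooled, weights


# Two frames with 3 and 5 selected patches (random embeddings for illustration).
rng = np.random.default_rng(0)
d, d_k, d_v = 8, 4, 4
patch_sets = [rng.normal(size=(3, d)), rng.normal(size=(5, d))]
tokens, mask = pad_patch_sets(patch_sets, d)
query = rng.normal(size=d_k)
w_k = rng.normal(size=(d, d_k))
w_v = rng.normal(size=(d, d_v))
pooled, weights = masked_attention_pool(tokens, mask, query, w_k, w_v)
print(pooled.shape, weights.round(3))
```

Masking rather than truncating means the pooled vector for a frame does not depend on how much padding its batch needed. The per-patch attention weights are also the kind of quantity one can inspect when analyzing which patches the agent attends to.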
