MOOSS: Mask-Enhanced Temporal Contrastive Learning for Smooth State Evolution in Visual Reinforcement Learning

Jiarui Sun, M. Ugur Akcal, Wei Zhang, Girish Chowdhary

TL;DR

MOOSS addresses the challenge of sample efficiency in visual reinforcement learning by explicitly modeling state evolution through a graph-based spatial-temporal masking regime and a multi-level temporal contrastive objective. It employs dual encoders (one momentum) and a Transformer-based predictive decoder to generate query states from masked inputs, while a temporal contrastive loss enforces smooth, hierarchical state similarities across time. The approach yields significant improvements on DMControl and Atari benchmarks, outperforming strong baselines and providing thorough ablations that highlight the benefits of both masking and multi-level contrastive learning. The method advances state representation learning in visual RL and offers practical gains in sample efficiency with open-source code available for reproducibility and further exploration.

Abstract

In visual Reinforcement Learning (RL), learning from pixel-based observations poses significant challenges on sample efficiency, primarily due to the complexity of extracting informative state representations from high-dimensional data. Previous methods such as contrastive-based approaches have made strides in improving sample efficiency but fall short in modeling the nuanced evolution of states. To address this, we introduce MOOSS, a novel framework that leverages a temporal contrastive objective with the help of graph-based spatial-temporal masking to explicitly model state evolution in visual RL. Specifically, we propose a self-supervised dual-component strategy that integrates (1) a graph construction of pixel-based observations for spatial-temporal masking, coupled with (2) a multi-level contrastive learning mechanism that enriches state representations by emphasizing temporal continuity and change of states. MOOSS advances the understanding of state dynamics by disrupting and learning from spatial-temporal correlations, which facilitates policy learning. Our comprehensive evaluation on multiple continuous and discrete control benchmarks shows that MOOSS outperforms previous state-of-the-art visual RL methods in terms of sample efficiency, demonstrating the effectiveness of our method. Our code is released at https://github.com/jsun57/MOOSS.
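
The first component, graph-based spatial-temporal masking, can be illustrated with a small sketch. Following the Figure 3 caption, the observation sequence is split into non-overlapping cubes, adjacent cubes are connected, and a random walk on that graph selects the cubes to mask. The function name, the 6-neighbor adjacency, the uniform start node, and the stopping rule (walk until a target fraction of cubes has been visited) are assumptions made for illustration; the released code may differ.

```python
import random


def random_walk_mask(F, H, W, f, h, w, mask_ratio, seed=0):
    """Pick cubes to mask by a random walk on the cube grid (illustrative sketch).

    The F x H x W sequence is split into cubes of shape f x h x w. Each cube is
    a graph node connected to its face-adjacent neighbors (an assumption).
    Returns a set of (t, y, x) cube indices.
    """
    assert F % f == 0 and H % h == 0 and W % w == 0, "cubes must tile the input"
    dims = (F // f, H // h, W // w)
    total = dims[0] * dims[1] * dims[2]
    target = int(round(mask_ratio * total))  # number of cubes to mask
    rng = random.Random(seed)

    def neighbors(node):
        # Face-adjacent cubes along the time, height, and width axes.
        out = []
        for axis in range(3):
            for step in (-1, 1):
                nxt = list(node)
                nxt[axis] += step
                if 0 <= nxt[axis] < dims[axis]:
                    out.append(tuple(nxt))
        return out

    masked = set()
    if target == 0:
        return masked
    node = tuple(rng.randrange(d) for d in dims)  # assumed: uniform start node
    masked.add(node)
    while len(masked) < target:
        node = rng.choice(neighbors(node))  # one step of the walk
        masked.add(node)                    # revisits do not add new cubes
    return masked


def cubes_to_pixel_mask(masked, F, H, W, f, h, w):
    """Expand masked cube indices into a boolean F x H x W pixel mask."""
    mask = [[[False] * W for _ in range(H)] for _ in range(F)]
    for ct, cy, cx in masked:
        for t in range(ct * f, (ct + 1) * f):
            for y in range(cy * h, (cy + 1) * h):
                for x in range(cx * w, (cx + 1) * w):
                    mask[t][y][x] = True
    return mask


# Toy sizes chosen for readability, not taken from the paper.
masked_cubes = random_walk_mask(F=4, H=8, W=8, f=2, h=4, w=4, mask_ratio=0.5)
pixel_mask = cubes_to_pixel_mask(masked_cubes, 4, 8, 8, 2, 4, 4)
print(len(masked_cubes))  # 4 of the 8 cubes are masked
```

Because the walk only moves between adjacent cubes, the masked region is connected in space and time. This is one plausible way to disrupt spatial-temporal correlations rather than isolated pixels.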

Paper Structure (36 sections, 12 equations, 7 figures, 8 tables)

This paper contains 36 sections, 12 equations, 7 figures, 8 tables.

Figures (7)

  • Figure 1: t-SNE (van der Maaten and Hinton, 2008) visualization of the state representations from a trained visual RL agent on the reacher-easy task from the DeepMind Control Suite (Tassa et al., 2018). The state representations are encoded from an observation sequence $\mathbf{o}_{0:19}$ of length 20, guided by random actions. Numbers within the color-coded dots denote the temporal indices. Note that the t-SNE visualization exhibits a temporal order, suggesting a gradual, smooth evolution of the states.
  • Figure 2: The proposed MOOSS framework. We first perform graph-based spatial-temporal masking on the observation sequence $\mathbf{o}_{t:t+F-1}$. The masked observations are then fed into a query encoder, generating the embeddings $\tilde{\mathbf{s}}_i$. The unmasked observations are processed by a momentum key encoder, which generates the key state embeddings $\bar{\mathbf{s}}_{t:t+F-1}$. A predictive decoder further processes the query encoder outputs $\tilde{\mathbf{s}}_i$, generating the query state embeddings $\hat{\mathbf{s}}_{t:t+F-1}$ conditioned on the corresponding action embeddings (Embs) $\mathbf{a}_i$.
  • Figure 3: Illustration of our graph-based spatial-temporal masking. The observation sequence $\eta_o$ with shape $F \times H \times W$ is equally divided into non-overlapping cubes with shape $f \times h \times w$, constructing a spatial-temporal graph ${\mathcal{G}}$ with adjacent nodes connected. Masking is applied by simulating a random walk on the constructed graph.
  • Figure 4: Illustration of the temporal contrastive objective. This mock setup contains $3$ sampled sequences with $15$ query-key pairs in total (observation length is $F=5$; batch size is $3$), and models four similarity levels with $L=3$. Embeddings learned from the same sequence share the same color scheme. The temporal contrastive objective aims to capture a ranked order of state similarities, indicated by the diminishing color intensity from the main diagonal to the off-diagonal cells. In this example, $\Phi = \mathrm{sim}(\mathbf{q}_1, \mathbf{k}_4) = \mathrm{sim}(\mathbf{q}, \mathbf{k}_{\Delta=3})$, and $\Omega = \mathrm{sim}(\mathbf{q}_{14}, \mathbf{k}_{12}) = \mathrm{sim}(\mathbf{q}, \mathbf{k}_{\Delta=2})$. The gray cells denote learned similarity scores between $\mathbf{q}$ and $\mathbf{k}'$, i.e., query-key pairs that either belong to different sampled sequences or have a temporal distance larger than 3. These pairs belong to the lowest similarity level.
  • Figure A.1: Ablation on window size $L$ and masking ratio $p_m$.
  • ...and 2 more figures
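
The ranking structure described in the Figure 4 caption can be made concrete with a small sketch that, for every query-key pair in a batch, records the temporal distance $\Delta$ used for ranking. It reads the caption literally: pairs from the same sequence with $\Delta \le L$ are ranked by $\Delta$, and all other pairs (different sequences, or $\Delta > L$) fall into the lowest level, marked here with `None`. The function name and the `None` sentinel are hypothetical, and the sketch does not settle how these cells map onto the paper's count of similarity levels or onto its loss terms.

```python
def temporal_distance_table(batch_size, F, L):
    """Build an (B*F) x (B*F) table of temporal distances for ranking.

    Entry [q][k] is the temporal distance between query q and key k when both
    come from the same sequence and the distance is at most L; otherwise it is
    None, standing for the lowest similarity level (gray cells in Figure 4).
    Indices are 0-based and sequences are laid out back to back.
    """
    n = batch_size * F
    table = [[None] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            same_sequence = (q // F) == (k // F)
            delta = abs(q % F - k % F)
            if same_sequence and delta <= L:
                table[q][k] = delta
    return table


# Mock setup from the Figure 4 caption: 3 sequences, F = 5, L = 3.
table = temporal_distance_table(batch_size=3, F=5, L=3)

# The caption uses 1-based indices, so subtract 1 when indexing.
phi = table[1 - 1][4 - 1]      # sim(q_1, k_4)   -> temporal distance 3
omega = table[14 - 1][12 - 1]  # sim(q_14, k_12) -> temporal distance 2
print(phi, omega)  # 3 2
```

A smaller recorded distance corresponds to a higher expected similarity, so the main diagonal (distance 0) holds the most similar pairs, and the `None` cells are the pairs the objective treats as least similar.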