Contrastive Behavioral Similarity Embeddings for Generalization in Reinforcement Learning

Rishabh Agarwal, Marlos C. Machado, Pablo Samuel Castro, Marc G. Bellemare

TL;DR

This work addresses the challenge of generalization in reinforcement learning by exploiting the sequential structure of RL. It does so through a policy similarity metric (PSM) between states and a contrastive learning framework that produces policy similarity embeddings (PSEs). The PSM is theoretically grounded and yields a bound on the suboptimality of transferring a policy across environments. Combined with contrastive metric embeddings (CMEs), it explicitly encodes invariances in optimal behavior across related environments and improves zero-shot generalization on diverse benchmarks. Empirically, PSEs outperform standard regularization and bisimulation-based baselines on a pixel-based Jumping Task, LQR with distractors, and the Distracting DM Control Suite. They are also robust to suboptimal policies and task variations. The results suggest that aligning representations with behavioral similarity is an orthogonal contribution to existing data augmentation and domain-generalization techniques in RL, with practical benefits for generalization in complex, high-dimensional tasks.

Abstract

Reinforcement learning methods trained on few environments rarely learn policies that generalize to unseen environments. To improve generalization, we incorporate the inherent sequential structure in reinforcement learning into the representation learning process. This approach is orthogonal to recent approaches, which rarely exploit this structure explicitly. Specifically, we introduce a theoretically motivated policy similarity metric (PSM) for measuring behavioral similarity between states. PSM assigns high similarity to states for which the optimal policies in those states as well as in future states are similar. We also present a contrastive representation learning procedure to embed any state similarity metric, which we instantiate with PSM to obtain policy similarity embeddings (PSEs). We demonstrate that PSEs improve generalization on diverse benchmarks, including LQR with spurious correlations, a jumping task from pixels, and Distracting DM Control Suite.
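The abstract says PSM assigns high similarity to states whose optimal actions agree, both in those states and in future states. The sketch below assumes a bisimulation-style recursive form: d(x, y) = DIST(π*(x), π*(y)) + γ W_1(d)(P^{π*}(·|x), P^{π*}(·|y)). It specializes this to two trajectories with deterministic transitions, so the Wasserstein term reduces to the distance between successor states.

Choosing total variation as DIST and the handling of trajectory ends are assumptions. The helper names are hypothetical and are not the authors' code.

```python
import numpy as np

def tv_distance(p, q):
    """Total variation distance between two action distributions."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return 0.5 * float(np.abs(p - q).sum())

def psm_between_trajectories(pi_x, pi_y, gamma=0.99):
    """Hypothetical helper: PSM-style distances d[i, j] between the states
    of two deterministic trajectories x_0..x_{n-1} and y_0..y_{m-1}.

    pi_x[i] and pi_y[j] are the (assumed optimal) action distributions in
    x_i and y_j. Assumptions: DIST is total variation, the last state of a
    trajectory is absorbing while the other trajectory continues, and the
    pair of final states has no future term.
    """
    n, m = len(pi_x), len(pi_y)
    d = np.zeros((n, m))
    # Fill backwards so every successor entry is computed before it is used.
    for i in range(n - 1, -1, -1):
        for j in range(m - 1, -1, -1):
            # Local term: how differently the policy acts in x_i and y_j.
            cost = tv_distance(pi_x[i], pi_y[j])
            # Deterministic transitions: the Wasserstein term becomes the
            # distance between the successor states.
            if i < n - 1 and j < m - 1:
                future = d[i + 1, j + 1]
            elif i < n - 1:      # y has ended and stays at y_{m-1}
                future = d[i + 1, j]
            elif j < m - 1:      # x has ended and stays at x_{n-1}
                future = d[i, j + 1]
            else:                # both trajectories have ended
                future = 0.0
            d[i, j] = cost + gamma * future
    return d

# Toy jumping-task example with one-hot actions ("right", "jump").
RIGHT, JUMP = [1.0, 0.0], [0.0, 1.0]
pi_x = [RIGHT, JUMP, RIGHT]
pi_y = [RIGHT, RIGHT, JUMP, RIGHT]
d = psm_between_trajectories(pi_x, pi_y, gamma=0.9)
# d[0, 1] == d[1, 2] == 0.0: each pair is equally far from its jump.
```

In this toy case, states are judged equivalent exactly when they sit the same number of steps before the jump, regardless of absolute position. This mirrors the invariant feature described in Figure 5.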

Paper Structure

This paper contains 36 sections, 9 theorems, 34 equations, 20 figures, 13 tables, 1 algorithm.

Key Result

theorem 1

[Bound on policy transfer] For any $y\in {\mathcal{Y}}$, let $Y^t_{y}\sim P^{\tilde{\pi}}(\cdot \,|\, Y^{t-1}_y)$ define the sequence of random states encountered starting in $Y^0_y=y$ and following policy $\tilde{\pi}$. We have:

Figures (20)

  • Figure 1: Jumping task: The agent (white block), learning from pixels, needs to jump over an obstacle (grey square). The challenge is to generalize to unseen obstacle positions and floor heights in test tasks using a small number of training tasks. We show the agent's trajectories using faded blocks.
  • Figure 2: Cyan edges represent actions with a positive reward, which are also the optimal actions. Rewards are zero everywhere else. $x_0, y_0$ are the start states, while $x_2, y_2$ are the terminal states.
  • Figure 3: Architecture for learning CMEs. Given an input pair $(x, y)$, we first apply the (optional) data augmentation operator $\Psi$ to produce the input augmentations $\Psi_x := \Psi(x), \Psi_y := \Psi(y)$. When not using data augmentation, $\Psi$ is equal to the identity operator, that is, $\forall x\ \Psi(x) = x$. The agent's policy network then outputs the representations for these augmentations by applying the encoder $f_\theta$, that is, $f_x = f_\theta(\Psi_x),\ f_y = f_\theta(\Psi_y)$. These representations are projected using a non-linear projector $h_\theta$ to obtain the embedding $z_{\theta}$, that is, $z_\theta(x) = h_\theta(f_x),\ z_\theta(y) = h_\theta(f_y)$. These metric embeddings are trained using the contrastive loss defined in Equation (4). The policy $\pi_\theta$ is an affine function of the representation, that is, $\pi_\theta(\cdot | y) = W^{T}f_y + b$, where $W, b$ are learned weights and biases. The entire network is trained end-to-end jointly using the reinforcement learning (or imitation learning) loss in conjunction with the auxiliary contrastive loss.
  • Figure 4: Jumping Task: Visualization of average performance of PSEs with data augmentation across different grid configurations. We plot the median performance across 100 runs. Each tile in the grid represents a different task (obstacle position/floor height combination). For each grid configuration, the floor height varies along the $y$-axis (11 heights), while the obstacle position varies along the $x$-axis (26 locations). The red letter T indicates the training tasks. The random grid depicts only one instance, since each run used a different train/test split. Beige tiles are tasks that PSEs with data augmentation solved, while black tiles are tasks they did not solve. These results were selected from the 100 runs to illustrate what the average reported performance looks like.
  • Figure 5: Embedding visualization. (a) Optimal trajectories on the original jumping task (visualized as coloured blocks) with different obstacle positions. We visualize the hidden representations using UMAP, where the color of each point indicates the task of the corresponding observation. Points with the same number label correspond to the same distance of the agent from the obstacle, which is the underlying optimal invariant feature across tasks.
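Figure 3 describes training the embeddings $z_\theta$ with an auxiliary contrastive loss (Equation (4) in the paper, not reproduced here). The sketch below shows one soft-weighted, InfoNCE-style loss of this kind. It assumes similarities Γ(x, y) = exp(-d(x, y)/β) derived from a state metric d such as PSM:

- For each y, the x with the highest Γ serves as the positive.
- Every other x is a negative, weighted by 1 - Γ.

This form and the names `soft_contrastive_loss`, `beta` and `temperature` are assumptions, and the loss may differ from Equation (4). Precomputed NumPy embeddings stand in for $h_\theta(f_\theta(\Psi(\cdot)))$. In practice the loss would be computed in an autodiff framework so that gradients reach the encoder.

```python
import numpy as np

def cosine_similarity_matrix(zx, zy):
    """Pairwise cosine similarities between rows of zx (n, k) and zy (m, k)."""
    zx = zx / np.linalg.norm(zx, axis=1, keepdims=True)
    zy = zy / np.linalg.norm(zy, axis=1, keepdims=True)
    return zx @ zy.T

def soft_contrastive_loss(zx, zy, d, beta=1.0, temperature=0.1):
    """Hypothetical soft InfoNCE-style loss, averaged over the states y.

    zx: (n, k) embeddings of states from one training environment.
    zy: (m, k) embeddings of states from another training environment.
    d:  (n, m) state distances, e.g. PSM values.
    """
    sim = np.exp(-d / beta)                      # metric similarity in (0, 1]
    s = cosine_similarity_matrix(zx, zy) / temperature
    n = zx.shape[0]
    losses = []
    for j in range(zy.shape[0]):
        pos = int(np.argmax(sim[:, j]))          # closest x to y_j under d
        w_pos = sim[pos, j] * np.exp(s[pos, j])
        neg = np.arange(n) != pos
        # Negatives are down-weighted when they are also similar to y_j.
        w_neg = np.sum((1.0 - sim[neg, j]) * np.exp(s[neg, j]))
        losses.append(-np.log(w_pos / (w_pos + w_neg)))
    return float(np.mean(losses))

# Toy check: x_i is behaviorally equivalent only to y_i.
d = 5.0 * (1.0 - np.eye(3))
z = np.eye(3)
aligned = soft_contrastive_loss(z, z, d)
shuffled = soft_contrastive_loss(z, z[[1, 2, 0]], d)
# aligned < shuffled: embeddings that agree with the metric get lower loss.
```

Soft weighting lets the metric decide how strongly each pair is pulled together or pushed apart, instead of relying on hard positive/negative labels.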

Theorems & Definitions (17)

  • theorem 1
  • lemma 1
  • proof
  • lemma 2
  • proof
  • lemma 3
  • proof
  • proposition A.1
  • proof
  • theorem 1