CausalCOMRL: Context-Based Offline Meta-Reinforcement Learning with Causal Representation

Zhengzhe Zhang, Wenjia Meng, Haoliang Sun, Gang Pan

TL;DR

This paper proposes CausalCOMRL, a context-based OMRL method that integrates causal representation learning and incorporates causal relationships into task representations, enhancing the generalizability of RL agents. Mutual information optimization and contrastive learning further sharpen the distinction between the representations of different tasks.

Abstract

Context-based offline meta-reinforcement learning (OMRL) methods have achieved notable success by leveraging pre-collected offline datasets to learn task representations that guide policy learning. However, current context-based OMRL methods often introduce spurious correlations, in which task components are incorrectly correlated because of confounders. These correlations can degrade policy performance when the confounders in the test task differ from those in the training task. To address this problem, we propose CausalCOMRL, a context-based OMRL method that integrates causal representation learning. This approach uncovers causal relationships among the task components and incorporates them into task representations, enhancing the generalizability of RL agents. We further sharpen the distinction between the task representations of different tasks by using mutual information optimization and contrastive learning. Using these causal task representations, we employ SAC to optimize policies on meta-RL benchmarks. Experimental results show that CausalCOMRL outperforms other methods on most benchmarks.
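
The abstract mentions contrastive learning as a way to separate the representations of different tasks, but does not state the loss. As a rough illustration only, the sketch below uses the common InfoNCE form of a contrastive objective; this may differ from the loss actually used in CausalCOMRL. The function names (`cosine`, `info_nce`), the temperature value, and the 2-D embeddings are all hypothetical choices made for this example.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchor, positive, negatives, temperature=0.1):
    """InfoNCE loss for one anchor embedding (an assumed, generic form).

    The positive is an embedding from the same task as the anchor;
    the negatives are embeddings from other tasks.
    """
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    # Numerically stable log-sum-exp over all candidates.
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    # -log softmax probability assigned to the positive pair.
    return log_denominator - logits[0]


# Made-up 2-D task embeddings, for illustration only.
z_anchor = [1.0, 0.0]
z_same_task = [0.9, 0.1]
z_other_tasks = [[0.0, 1.0], [-1.0, 0.2]]

# Low loss: the same-task embedding is the closest candidate.
well_separated = info_nce(z_anchor, z_same_task, z_other_tasks)

# High loss: an other-task embedding sits closer than the "positive".
poorly_separated = info_nce(z_anchor, [0.0, 1.0], [[0.9, 0.1], [-1.0, 0.2]])
```

Minimizing a loss of this kind pulls embeddings from the same task together and pushes embeddings from different tasks apart, which is the behavior the abstract describes as improving the distinction between task representations.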

Paper Structure

This paper contains 20 sections, 18 equations, 6 figures, 4 tables, 2 algorithms.

Figures (6)

  • Figure 1: Causal graph and encoder performance comparison. (a) Causal Graph Example. In the causal graph example, the nodes represent the state, action, and reward at varying timesteps, with edges indicating the causal relationships among them. $t$ represents the timestep. (b) t-SNE visualization of the task representation embedding vectors in Walker-Rand-Params.
  • Figure 2: Framework of CausalCOMRL: (a) Causal task encoder training module. (b) Meta-training module. (c) Meta-testing module.
  • Figure 3: Average test returns of CausalCOMRL against representative context-based OMRL methods on four environments in out-of-distribution tasks. The $X$-axis and $Y$-axis denote the timesteps and average return, respectively. The shaded region shows standard deviation across 5 seeds.
  • Figure 4: The t-SNE visualization of the task representation space in Half-Cheetah-Vel. Test task points are uniformly sampled from test tasks and color-coded from red to purple for velocities 0 to 3.
  • Figure 5: Average returns of CausalCOMRL against the baselines in the in-distribution training tasks. The shaded region shows standard deviation across 5 seeds.
  • ...and 1 more figure
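
The captions of Figures 3 and 5 describe a shaded region showing the standard deviation across 5 seeds. A minimal sketch of how such a mean curve and band could be computed is given below. It assumes the population standard deviation (the captions do not say whether population or sample deviation is used), and the helper name `mean_and_band` and all return values are made up for illustration; they are not results from the paper.

```python
import statistics


def mean_and_band(returns_per_seed):
    """Compute the mean curve and a +/- one-standard-deviation band.

    returns_per_seed: list of per-seed curves, each a list holding the
    average return at every evaluation step (all curves equal length).
    """
    means, lowers, uppers = [], [], []
    # zip(*...) groups the values of all seeds at the same evaluation step.
    for step_values in zip(*returns_per_seed):
        mu = statistics.fmean(step_values)
        sigma = statistics.pstdev(step_values)  # assumption: population std
        means.append(mu)
        lowers.append(mu - sigma)
        uppers.append(mu + sigma)
    return means, lowers, uppers


# Made-up returns for 5 seeds at 3 evaluation steps (not from the paper).
curves = [
    [0.0, 10.0, 20.0],
    [2.0, 12.0, 22.0],
    [4.0, 14.0, 24.0],
    [6.0, 16.0, 26.0],
    [8.0, 18.0, 28.0],
]

means, lowers, uppers = mean_and_band(curves)
```

The `means` list would be drawn as the solid line, and the area between `lowers` and `uppers` as the shaded region.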