Disentangling Policy from Offline Task Representation Learning via Adversarial Data Augmentation

Chengxing Jia, Fuxiang Zhang, Yi-Chen Li, Chen-Xiao Gao, Xu-Hui Liu, Lei Yuan, Zongzhang Zhang, Yang Yu

TL;DR

A novel algorithm is introduced to disentangle the impact of behavior policy from task representation learning through a process called adversarial data augmentation, which aims to create adversarial examples designed to confound learned task representations and lead to incorrect task identification.

Abstract

Offline meta-reinforcement learning (OMRL) allows an agent to tackle novel tasks while relying solely on a static dataset. For precise and efficient task identification, existing OMRL research suggests learning separate task representations that are incorporated into the policy input, thus forming a context-based meta-policy. A major approach to training task representations is contrastive learning on multi-task offline data. The dataset typically encompasses interactions from various policies (i.e., the behavior policies), thus providing a plethora of contextual information regarding different tasks. Nonetheless, amassing data from a substantial number of policies is not only impractical but also often unattainable in realistic settings. Instead, we resort to a more constrained yet practical scenario, where multi-task data collection occurs with a limited number of policies. We observe that task representations learned by previous OMRL methods tend to correlate spuriously with the behavior policy instead of reflecting the essential characteristics of the task, resulting in unfavorable out-of-distribution generalization. To alleviate this issue, we introduce a novel algorithm to disentangle the impact of the behavior policy from task representation learning through a process called adversarial data augmentation. Specifically, the objective of adversarial data augmentation is not merely to generate data analogous to the offline data distribution; instead, it aims to create adversarial examples designed to confound learned task representations and lead to incorrect task identification. Our experiments show that learning from such adversarial samples significantly enhances the robustness and effectiveness of the task identification process and realizes satisfactory out-of-distribution generalization.
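As a rough illustration of the contrastive task representation learning mentioned in the abstract, the sketch below computes a supervised InfoNCE-style loss in which context embeddings from the same task are treated as positives and all others as negatives. It is not the loss used by the authors: the function name `info_nce_loss`, the `temperature` parameter, and the toy embeddings are assumptions made purely for illustration.

```python
import numpy as np


def info_nce_loss(embeddings, task_ids, temperature=0.1):
    """Supervised InfoNCE-style loss (illustrative sketch, not the paper's objective).

    embeddings: (n, d) array of context embeddings.
    task_ids:   (n,) array; rows with equal ids are treated as positives.
    """
    z = np.asarray(embeddings, dtype=np.float64)
    task_ids = np.asarray(task_ids)
    n = z.shape[0]
    if n < 2:
        raise ValueError("need at least two embeddings")

    # L2-normalise so that dot products are cosine similarities.
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    logits = z @ z.T / temperature

    # Exclude each anchor's similarity with itself from the softmax.
    self_mask = np.eye(n, dtype=bool)
    logits = np.where(self_mask, -np.inf, logits)

    # Numerically stable row-wise log-softmax.
    row_max = logits.max(axis=1, keepdims=True)
    shifted = logits - row_max
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Positives: same task id, different sample.
    pos_mask = (task_ids[:, None] == task_ids[None, :]) & ~self_mask
    pos_count = pos_mask.sum(axis=1)
    valid = pos_count > 0
    if not valid.any():
        raise ValueError("no anchor has a positive sample")

    # Average log-probability of the positives for each valid anchor.
    pos_log_prob = np.where(pos_mask, log_prob, 0.0).sum(axis=1)
    return float(-(pos_log_prob[valid] / pos_count[valid]).mean())


# Toy usage: two tasks whose embeddings point in nearly orthogonal directions.
toy_embeddings = np.array([[1.0, 0.0], [1.0, 0.01], [0.0, 1.0], [0.01, 1.0]])
loss_matching = info_nce_loss(toy_embeddings, [0, 0, 1, 1])
loss_mismatched = info_nce_loss(toy_embeddings, [0, 1, 0, 1])
```

In this toy setup the loss is low when the labels agree with the embedding clusters and high when they do not. The abstract's concern is that, with few behavior policies, such clusters may track the policy that generated the data rather than the task itself.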

Paper Structure

This paper contains 27 sections, 1 theorem, 12 equations, 8 figures, 5 tables, and 1 algorithm.

Key Result

Theorem A.3

(Task Distribution Shift). Given the probability space of the meta-task distribution $(\mathcal{M}, 2^{\mathcal{M}}, P)$, for each model $M$ sampled according to $P$, we learn a surrogate model $\hat{M}$ and let $\hat{P}$ be the task distribution associated with the real model $\hat{M}$, the TV divergence ...
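For reference, the standard total variation (TV) divergence between two distributions $P$ and $\hat{P}$ on the measurable space $(\mathcal{M}, 2^{\mathcal{M}})$ is given below. This is the textbook definition, assumed here for orientation only; the paper may use a different normalization, and the bound stated by the theorem is not reproduced.

$$
D_{\mathrm{TV}}(P, \hat{P}) = \sup_{A \subseteq \mathcal{M}} \left| P(A) - \hat{P}(A) \right|
$$

When $\mathcal{M}$ is countable, this equals $\frac{1}{2} \sum_{M \in \mathcal{M}} \left| P(M) - \hat{P}(M) \right|$.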

Figures (8)

  • Figure 1: The overall process of using adversarial data augmentation for offline meta-RL.
  • Figure 2: (a) Performance on InvertedPendulum-v2 with 1.0x gravity coefficient. (b) The relative representation metric (from Equation \ref{eq:relative-metric}) of different methods.
  • Figure 3: t-SNE visualization of task representations for (a) ReDA, (b) OM-SAC, (c) FOCAL, and (d) CORRO on the task set Walker2d Dof-Damping-1. Task representations from different tasks are shown in distinct colors.
  • Figure 4: Visualized trajectories of Walker2d-v2 under distinct values of the dof-damping hyper-parameter, including $0.5$, $1.0$, and $1.5$. Each trajectory starts from the same initial state and follows the same action sequence.
  • Figure 5: Visualization of the training and test datasets in InvertedPendulum-v2 from our didactic example. (a) The rendered images from the environment in different datasets. (b) The t-SNE plot of state-action pairs in different datasets. (c) The relative representation metrics (from Equation \ref{eq:relative-metric}) of two training datasets compared to the test dataset.
  • ...and 3 more figures

Theorems & Definitions (2)

  • Definition A.1
  • Theorem A.3