Prompting Decision Transformer for Few-Shot Policy Generalization

Mengdi Xu, Yikang Shen, Shun Zhang, Yuchen Lu, Ding Zhao, Joshua B. Tenenbaum, Chuang Gan

TL;DR

The paper addresses offline RL generalization to unseen tasks by leveraging a prompt-based inductive bias. It introduces the Prompt-based Decision Transformer (Prompt-DT), which conditions a Transformer-based policy on short trajectory prompts of length $K^{\star}$ to achieve few-shot generalization without finetuning. On five MuJoCo benchmarks, Prompt-DT outperforms its variants and the strong offline meta-RL baseline MACAW, remains robust to changes in prompt length, and extrapolates to out-of-distribution tasks. The results suggest that trajectory prompts can encode task information effectively and motivate broader use of prompt-based architectures for data-efficient RL.

Abstract

Humans can leverage prior experience and learn novel tasks from a handful of demonstrations. In contrast to offline meta-reinforcement learning, which aims to achieve quick adaptation through better algorithm design, we investigate the effect of architectural inductive bias on few-shot learning capability. We propose a Prompt-based Decision Transformer (Prompt-DT), which leverages the sequential modeling ability of the Transformer architecture and the prompt framework to achieve few-shot adaptation in offline RL. We design the trajectory prompt, which contains segments of the few-shot demonstrations and encodes task-specific information to guide policy generation. Our experiments in five MuJoCo control benchmarks show that Prompt-DT is a strong few-shot learner without any extra finetuning on unseen target tasks. Prompt-DT outperforms its variants and strong offline meta-RL baselines by a large margin with a trajectory prompt containing only a few timesteps. Prompt-DT is also robust to prompt length changes and can generalize to out-of-distribution (OOD) environments.
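
The trajectory prompt described above can be illustrated with a minimal NumPy sketch. It assumes each few-shot demonstration is stored as a dict of `observations`, `actions`, and `rewards` arrays, that prompt tokens carry undiscounted returns-to-go as in the Decision Transformer, and that the total prompt length $K^{\star}$ is split evenly across `num_segments` randomly chosen episodes. The function names and this splitting scheme are hypothetical and are not taken from the authors' code.

```python
import numpy as np


def returns_to_go(rewards):
    """Undiscounted return-to-go: R_t = r_t + r_{t+1} + ... + r_{T-1}."""
    rewards = np.asarray(rewards, dtype=float)
    # A reversed cumulative sum yields every suffix sum in one pass.
    return np.cumsum(rewards[::-1])[::-1]


def build_trajectory_prompt(episodes, k_star, num_segments=1, rng=None):
    """Sample a prompt of k_star timesteps from few-shot demonstration episodes.

    Hypothetical helper: k_star is split evenly into num_segments segments,
    each drawn from a randomly chosen episode at a random start index.
    """
    if k_star % num_segments != 0:
        raise ValueError("k_star must be divisible by num_segments")
    seg_len = k_star // num_segments
    rng = np.random.default_rng() if rng is None else rng
    parts = {k: [] for k in ("rtg", "observations", "actions", "timesteps")}
    for _ in range(num_segments):
        ep = episodes[rng.integers(len(episodes))]
        T = len(ep["rewards"])
        if T < seg_len:
            raise ValueError("episode is shorter than the segment length")
        start = int(rng.integers(0, T - seg_len + 1))
        stop = start + seg_len
        # Each prompt timestep keeps its (return-to-go, state, action) triple.
        parts["rtg"].append(returns_to_go(ep["rewards"])[start:stop])
        parts["observations"].append(np.asarray(ep["observations"])[start:stop])
        parts["actions"].append(np.asarray(ep["actions"])[start:stop])
        parts["timesteps"].append(np.arange(start, stop))
    # Segments are concatenated along time into one prompt sequence.
    return {k: np.concatenate(v, axis=0) for k, v in parts.items()}
```

For example, `build_trajectory_prompt(demos, k_star=5)` would draw a single 5-step segment, matching the $K^{\star}=5$ setting reported for Cheetah-dir, Cheetah-vel, and Ant-dir.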

Paper Structure (33 sections, 3 equations, 8 figures, 7 tables, 3 algorithms)

Figures (8)

  • Figure 1: Prompt-DT for few-shot policy generalization. The left shows the few-shot demonstration dataset ${\mathcal{P}}_i$ for each task ${\mathcal{T}}_i \in {\mathcal{T}}^{train} \cup {\mathcal{T}}^{test}$. The trajectory prompt is defined as a trajectory sequence of length $K^{\star}$ sampled from various episodes stored in ${\mathcal{P}}_i$. In both pretraining and few-shot evaluation, Prompt-DT takes the trajectory prompt together with the most recent $K$-step history as input, and autoregressively outputs actions corresponding to each state in the input sequence.
  • Figure 2: Episodic accumulated returns in never-before-seen tasks of Prompt-DT, Prompt-based Behavior Cloning (Prompt-MT-BC), Multi-task Offline RL (MT-ORL), Multi-task Behavior Cloning (MT-BC-Finetune), and Meta-Actor Critic with Advantage Weighting (MACAW). All methods are trained with the same expert dataset ${\mathcal{D}}$. Each plot is run with 3 seeds. Prompt-DT and Prompt-MT-BC have a few-shot dataset ${\mathcal{P}}$ containing expert demonstrations. Cheetah-dir, Cheetah-vel and Ant-dir use prompts of length $K^{\star}=5$, Dial uses prompts of length $K^{\star}=15$, and Meta-World reach-v2 uses prompts of length $K^{\star}=2$. At test time, MT-BC-Finetune and MACAW are finetuned on an amount of data equal to the prompt length. The dashed lines show the best performance of MACAW reported in mitchell2021offline. Prompt-augmented methods, including Prompt-DT and Prompt-MT-BC, outperform the baselines across environments with a short trajectory prompt.
  • Figure 3: Ablation: The effect of prompt quality on Prompt-DT's few-shot generalization ability. Within each plot, Prompt-DT is trained on a dataset and demonstrations of the same quality. The left, middle, and right plots correspond to expert, medium, and random datasets collected in Cheetah-vel. Each plot is run with 2 seeds. When testing Prompt-DT's few-shot generalization ability, we feed it trajectory prompts of different qualities. The results show that Prompt-DT tries to generate policies that match the prompt quality, and that the quality of the training dataset also affects Prompt-DT's few-shot generalization ability.
  • Figure 4: Episodic accumulated returns in novel tasks with goals out of training tasks' goal range in Ant-dir. Each plot is run with 3 seeds. Prompt-MT-BC scores the highest reward and outperforms baselines without trajectory prompts by a large margin.
  • Figure 5: Ablation: The effect of trajectory prompt length on Prompt-DT's performance. Each plot is run with 3 seeds.
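
Figure 1's input construction, a trajectory prompt followed by the most recent $K$-step history, can be sketched as follows. The sketch assumes the Decision Transformer's interleaved (return-to-go, state, action) token layout, in which each action is predicted from the hidden state at its state token. As an additional assumption, it reads action predictions only for the non-prompt steps, as one would when computing a training loss. All names are hypothetical, and the Transformer itself is omitted.

```python
import numpy as np

RTG, STATE, ACTION = 0, 1, 2  # token types in the interleaved sequence


def assemble_prompt_dt_input(prompt, history, K):
    """Prepend a trajectory prompt to the last K steps of the current history.

    prompt, history: dicts sharing the keys 'rtg' (T,), 'observations'
    (T, obs_dim) and 'actions' (T, act_dim). Returns the concatenated arrays
    and a boolean mask that is True at prompt timesteps.
    """
    recent = {k: np.asarray(v)[-K:] for k, v in history.items()}  # most recent K steps
    seq = {k: np.concatenate([prompt[k], recent[k]], axis=0) for k in prompt}
    n_prompt = len(prompt["rtg"])
    is_prompt = np.arange(len(seq["rtg"])) < n_prompt
    return seq, is_prompt


def token_layout(num_steps):
    """Token types for num_steps timesteps: (R, s, a) repeated, 3 tokens per step."""
    return np.tile([RTG, STATE, ACTION], num_steps)


def action_readout_positions(is_prompt):
    """Indices of the state tokens whose outputs predict non-prompt actions.

    Timestep t occupies tokens 3t (R), 3t + 1 (s) and 3t + 2 (a); following the
    Decision Transformer convention, a_t is read from the s_t token.
    """
    steps = np.nonzero(~is_prompt)[0]
    return 3 * steps + 1
```

Because the prompt is simply prepended along the time axis, the same sequence model serves every task; only the prompt contents change between tasks, which is what allows adaptation without finetuning.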