
Pre-trained Language Models Improve the Few-shot Prompt Ability of Decision Transformer

Yu Yang, Pan Xu

TL;DR

The paper addresses the data-hungry nature of prompt learning in offline reinforcement learning by introducing LPDT, a framework that initializes Decision Transformers with pre-trained language models, uses LoRA for parameter-efficient fine-tuning, and applies prompt regularization to distinguish tasks. Evaluated on MuJoCo and Meta-World, LPDT matches or outperforms baselines while using only 10% of the data on some tasks. Extensive ablations validate the contribution of language-model priors and prompt regularization, and show that the approach transfers across architectures (e.g., Reinformer, EDT). The approach reduces data requirements and extends prompt-based multi-task capabilities, offering a practical path to scalable offline RL in safety-constrained settings. The framework works across DT variants and LM choices, and the authors identify scaling to larger language models and broader task families as future work.

Abstract

Decision Transformer (DT) has emerged as a promising class of algorithms for offline reinforcement learning (RL), leveraging pre-collected datasets and the Transformer's ability to model long sequences. Recent works have shown that using parts of trajectories from training tasks as prompts in DT improves its performance on unseen tasks, giving rise to Prompt-DT methods. However, collecting data from specific environments can be both costly and unsafe in many scenarios, and because Transformer-based models are data-hungry, limited data leads to suboptimal performance and weak few-shot prompt abilities. Additionally, the limited datasets used in pre-training make it hard for Prompt-DT-style methods to distinguish between RL tasks through prompts alone. To address these challenges, we introduce the Language model-initialized Prompt Decision Transformer (LPDT) framework, which leverages pre-trained language models to provide rich prior knowledge for RL tasks and fine-tunes the sequence model with Low-Rank Adaptation (LoRA) for meta-RL problems. We further incorporate prompt regularization to differentiate between tasks based on their prompt feature representations. Comprehensive empirical studies demonstrate that initializing with a pre-trained language model provides useful prior knowledge and achieves performance similar to Prompt-DT with only 10% of the data on some MuJoCo control tasks. We also provide a thorough ablation study validating each component, including sequence modeling, language models, prompt regularization methods, and prompt strategies.
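The LoRA fine-tuning step can be sketched in a few lines. This is a minimal pure-Python illustration of the standard LoRA formulation (Hu et al., 2021), not the authors' code. The frozen pre-trained weight W is never modified. Only two small factors, A and B, are trained, so each adapted layer computes Wx + (alpha/rank)·B(Ax). The class `LoRALinear`, the helper `lora_param_counts`, the zero initialization of B, and the alpha/rank scaling are assumptions taken from the generic LoRA recipe. The rank and scaling LPDT actually uses are not given here, and no training loop is shown.

```python
import random


def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def lora_param_counts(d_in, d_out, rank):
    """Hypothetical helper: (frozen, trainable) parameter counts for one d_out x d_in layer."""
    return d_out * d_in, rank * (d_in + d_out)


class LoRALinear:
    """Frozen weight W plus a low-rank update: W x + (alpha / rank) * B (A x).

    Hypothetical illustration of the generic LoRA recipe: A (rank x d_in) starts
    small and random, B (d_out x rank) starts at zero, so before fine-tuning the
    layer reproduces the frozen pre-trained layer exactly. Only A and B would be
    updated during fine-tuning; W is never modified (no training loop is shown).
    """

    def __init__(self, weight, rank, alpha=1.0, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight  # frozen pre-trained weight, d_out x d_in
        self.scale = alpha / rank
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]

    def forward(self, x):
        frozen_out = matvec(self.weight, x)               # pre-trained path
        low_rank_out = matvec(self.B, matvec(self.A, x))  # trainable low-rank path
        return [f + self.scale * u for f, u in zip(frozen_out, low_rank_out)]


# Before fine-tuning (B == 0), the adapted layer matches the frozen layer.
layer = LoRALinear([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], rank=1, alpha=2.0)
print(layer.forward([1.0, 1.0]))       # [3.0, 7.0, 11.0]
print(lora_param_counts(768, 768, 8))  # (589824, 12288)
```

Take a 768 x 768 projection as an example (768 is the hidden size of GPT-2 small, used here only for illustration). A rank-8 adapter then trains 12,288 parameters against 589,824 frozen ones, about 2.1%. This is why LoRA suits fine-tuning a large pre-trained model on limited RL data.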

Paper Structure (27 sections, 12 equations, 2 figures, 10 tables, 1 algorithm)

Figures (2)

  • Figure 1: Overview of LPDT. We first initialize our algorithm with a pre-trained language model, which is trained on a large text corpus with the causal language modeling objective, i.e., predicting the next token. LPDT replaces the word embedding layers with linear layers, discarding the learned word features, so that the model can fully learn the features of RL trajectory tokens. We fine-tune the model with parameter-efficient methods such as Low-Rank Adaptation (LoRA): the initial weights of the language model are frozen and only the LoRA weights are updated. The input consists of prompts accompanied by training trajectories from the same tasks. Unlike traditional language models, which predict word tokens, our method predicts the action tokens in the RL trajectories. Additionally, we apply prompt regularization to the input prompts by introducing an additional loss on the prompt embeddings, which helps LPDT distinguish between different environments. More technical details of our method are presented in the Method section.
  • Figure 2: Training curves on three MuJoCo control tasks (Cheetah-dir, Cheetah-vel and Ant-dir) for Prompt-DT and our four methods: LPDT-Classifier, LPDT-InfoNCE, LPDT-Rein-Classifier and LPDT-Rein-InfoNCE. All methods are trained on the full dataset. Each curve shows the average return over 20 evaluation episodes on one unseen task (task 1 for Cheetah-dir, task 49 for Cheetah-vel and task 48 for Ant-dir). The curves show that LPDT needs fewer samples to reach good performance and is more stable during training.
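The figure captions describe prompt regularization as an additional loss on the prompt embeddings and name two variants, a classifier and InfoNCE. Below is a hedged sketch of a generic InfoNCE-style loss of the kind the InfoNCE variant suggests. The pairing of two prompts per task, cosine similarity, the temperature of 0.1, and the function names are assumptions, not details confirmed by this summary. The classifier variant would presumably apply a cross-entropy loss for predicting the task ID from the prompt embedding instead.

```python
import math


def cosine(u, v):
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchors, positives, temperature=0.1):
    """Hypothetical batch InfoNCE loss over prompt embeddings.

    anchors[i] and positives[i] embed two prompts drawn from the same task i;
    positives[j] for j != i (other tasks) serve as negatives. The loss is the
    mean cross-entropy of picking the matching positive among all candidates.
    """
    n = len(anchors)
    total = 0.0
    for i in range(n):
        logits = [cosine(anchors[i], positives[j]) / temperature for j in range(n)]
        peak = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
        total += log_sum_exp - logits[i]  # -log softmax probability of the positive
    return total / n


# Prompts from different tasks that embed far apart -> near-zero loss.
separated = [[1.0, 0.0], [0.0, 1.0]]
print(round(info_nce(separated, separated), 4))  # 0.0
# Prompts that all embed to the same point cannot be told apart -> loss log(n).
collapsed = [[1.0, 0.0], [1.0, 0.0]]
print(round(info_nce(collapsed, collapsed), 4))  # 0.6931
```

In training, a term like this would be added to the action-prediction loss with a weighting coefficient (hypothetical here, e.g. `loss = action_loss + lam * info_nce(...)`). Minimizing it pushes prompt embeddings from different tasks apart, which is what lets the model tell environments apart from the prompt alone.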