
QT-TDM: Planning With Transformer Dynamics Model and Autoregressive Q-Learning

Mostafa Kotb, Cornelius Weber, Muhammad Burhan Hafez, Stefan Wermter

TL;DR

The proposed method, QT-TDM, integrates the robust predictive capabilities of Transformers as dynamics models with the efficacy of a model-free Q-Transformer to mitigate the computational burden associated with real-time planning.

Abstract

Inspired by the success of the Transformer architecture in natural language processing and computer vision, we investigate the use of Transformers in Reinforcement Learning (RL), specifically in modeling the environment's dynamics using Transformer Dynamics Models (TDMs). We evaluate the capabilities of TDMs for continuous control in real-time planning scenarios with Model Predictive Control (MPC). While Transformers excel in long-horizon prediction, their tokenization mechanism and autoregressive nature lead to costly planning over long horizons, especially as the environment's dimensionality increases. To alleviate this issue, we use a TDM for short-term planning, and learn an autoregressive discrete Q-function using a separate Q-Transformer (QT) model to estimate a long-term return beyond the short-horizon planning. Our proposed method, QT-TDM, integrates the robust predictive capabilities of Transformers as dynamics models with the efficacy of a model-free Q-Transformer to mitigate the computational burden associated with real-time planning. Experiments in diverse state-based continuous control tasks show that QT-TDM is superior in performance and sample efficiency compared to existing Transformer-based RL models while achieving fast and computationally efficient inference.
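The abstract's core idea is to plan with the dynamics model for only a short horizon and let a learned value estimate cover everything after it. A minimal sketch of that idea follows. It assumes a deterministic `dynamics(state, action) -> (next_state, reward)` callable standing in for the TDM, and a `terminal_value(state)` callable standing in for the QT estimate; both names are hypothetical. The optimizer is plain random shooting, chosen only for brevity, and is not necessarily the planner used in the paper.

```python
import random
from typing import Callable, List, Sequence, Tuple

State = List[float]
Action = List[float]
Dynamics = Callable[[State, Action], Tuple[State, float]]


def truncated_return(
    state: State,
    actions: Sequence[Action],
    dynamics: Dynamics,
    terminal_value: Callable[[State], float],
    gamma: float = 0.99,
) -> float:
    """Discounted model rewards over len(actions) steps plus a discounted terminal value."""
    total, discount = 0.0, 1.0
    for action in actions:
        state, reward = dynamics(state, action)  # hypothetical stand-in for the TDM
        total += discount * reward
        discount *= gamma
    # The terminal value summarizes the return beyond the short horizon H.
    return total + discount * terminal_value(state)


def plan_first_action(
    state: State,
    dynamics: Dynamics,
    terminal_value: Callable[[State], float],
    horizon: int,
    action_dim: int,
    num_samples: int = 256,
    gamma: float = 0.99,
    seed: int = 0,
) -> Action:
    """Random-shooting MPC: score sampled action sequences, return the best first action."""
    rng = random.Random(seed)
    best_seq, best_score = None, float("-inf")
    for _ in range(num_samples):
        seq = [[rng.uniform(-1.0, 1.0) for _ in range(action_dim)] for _ in range(horizon)]
        score = truncated_return(state, seq, dynamics, terminal_value, gamma)
        if score > best_score:
            best_seq, best_score = seq, score
    return best_seq[0]  # execute only the first action, then replan at the next step


# Toy 1-D example: drive the state toward 0; reward and terminal value penalize distance.
def toy_dynamics(state: State, action: Action) -> Tuple[State, float]:
    next_state = [state[0] + action[0]]
    return next_state, -next_state[0] ** 2


def toy_terminal_value(state: State) -> float:
    return -state[0] ** 2


first_action = plan_first_action([1.0], toy_dynamics, toy_terminal_value, horizon=2, action_dim=1)
print(first_action)
```

Only the first action is executed, and the whole search repeats at the next time step. Because each model step would itself be an autoregressive Transformer pass, keeping `horizon` short is what keeps the per-step planning cost low.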

Paper Structure

This paper contains 25 sections, 9 equations, 5 figures, 5 tables, and 2 algorithms.

Figures (5)

  • Figure 1: QT-TDM Inference: The learned TDM plans over a short planning horizon $H$, while the learned QT model estimates an autoregressive terminal value $Q^i_H$ for each action dimension $a^i_H$, which guides planning beyond the myopic horizon.
  • Figure 2: QT-TDM Architecture, which consists of two modules: (a) TDM and (b) QT. Both modules use a GPT-like Transformer as their main component and share the same tokenization scheme. The state $s_t$ is tokenized into a single token using a learned linear layer. The $N$-dimensional action is tokenized per dimension: each dimension is discretized independently into $K$ bins, and the resulting bin index is embedded using an embedding table. The TDM module predicts the next state $\hat{s}_{t+1}$ and the reward $\hat{r}_t$ and is trained on $L$ sampled time steps (for brevity, only two time steps are shown). The QT module predicts Q-values for each action dimension, $\hat{q}_t^{i,1:K}$ for all $i \in \{1, \dots, N\}$.
  • Figure 3: Continuous Control Tasks. Two locomotion tasks with a high-dimensional action space (Walker and Cheetah) and one sparse-reward task (Reacher) from DMC [dmc]. Six robotic manipulation tasks (d)-(i) posing various challenges, from MetaWorld [metaworld].
  • Figure 4: Learning curves. Top row: three tasks from DMC, with episode return as the performance metric. Middle and bottom rows: six tasks from MetaWorld, with success rate (%) as the performance metric. Curves show the mean over 3 seeds; shaded areas show standard deviations. For DreamerV3, we report the final performance from [tdmpc2].
  • Figure 5: QT-TDM with different planning horizons $H$ on the Cheetah Run task.
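The per-dimension action tokenization described for Figure 2 can be sketched as follows, together with the autoregressive way a Q-Transformer-style model picks one action dimension at a time. The sketch makes several assumptions:

- Actions are normalized to [-1, 1].
- Bins are uniform, and bin indices are decoded to bin centers.
- `q_fn(state, prefix)` is a hypothetical callable that returns the $K$ Q-values for the next dimension, given the tokens already chosen.

The greedy dimension-by-dimension argmax follows the general Q-Transformer idea and is not taken from this paper's code. The embedding-table lookup that turns each bin index into a token is omitted.

```python
from typing import Callable, List, Sequence

QFn = Callable[[object, List[int]], Sequence[float]]


def discretize(action: Sequence[float], num_bins: int, low: float = -1.0, high: float = 1.0) -> List[int]:
    """Map each action dimension independently to one of num_bins uniform bins."""
    width = (high - low) / num_bins
    tokens = []
    for value in action:
        clipped = min(max(value, low), high)
        index = int((clipped - low) / width)
        tokens.append(min(index, num_bins - 1))  # value == high falls in the last bin
    return tokens


def undiscretize(tokens: Sequence[int], num_bins: int, low: float = -1.0, high: float = 1.0) -> List[float]:
    """Decode bin indices back to continuous values at the bin centers."""
    width = (high - low) / num_bins
    return [low + (k + 0.5) * width for k in tokens]


def greedy_autoregressive_action(state: object, q_fn: QFn, action_dim: int, num_bins: int) -> List[int]:
    """Choose dimension i by argmax over its K Q-values, conditioned on dimensions < i."""
    prefix: List[int] = []
    for _ in range(action_dim):
        q_values = q_fn(state, prefix)  # K values for dimension len(prefix)
        prefix.append(max(range(num_bins), key=lambda k: q_values[k]))
    return prefix


print(discretize([-1.0, 0.0, 1.0, 0.3], num_bins=4))  # [0, 2, 3, 2]
print(undiscretize([0, 3], num_bins=4))  # [-0.75, 0.75]
```

Per-dimension bins keep the action vocabulary at $K$ entries per dimension instead of $K^N$ joint actions. The cost is $N$ action tokens per time step, and that growth in sequence length with $N$ is the planning cost the abstract attributes to the tokenization mechanism as dimensionality increases.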