Deep Transformer Q-Networks for Partially Observable Reinforcement Learning

Kevin Esslinger, Robert Platt, Christopher Amato

TL;DR

The paper tackles partial observability in reinforcement learning by introducing DTQN, a transformer decoder-based Q-network that encodes an agent's history through self-attention and learns positional encodings. It trains with an intermediate Q-value prediction objective, enabling supervision from Q-values across all timesteps in the history and improving learning stability. Across multiple POMDP-like domains, DTQN outperforms or matches strong baselines (DRQN, DQN, ATTN) with faster learning and higher final performance, while also providing interpretable attention visualizations. The work demonstrates the viability of transformer-based history models for partial observability and provides a modular implementation to serve as a benchmark for future transformer-based RL methods.
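To make the intermediate Q-value objective concrete, here is a minimal sketch of a loss that supervises the Q-values of every sub-history in the window, not only the last one. It assumes a standard one-step DQN-style target with a separate target network; the paper's exact target and loss (for example a Double-DQN target or a Huber loss) may differ. The function name `intermediate_q_loss` is hypothetical, and plain Python lists stand in for the tensors a real implementation would produce in one batched forward pass.

```python
from typing import Sequence


def intermediate_q_loss(
    q_pred: Sequence[Sequence[float]],
    q_next_target: Sequence[Sequence[float]],
    actions: Sequence[int],
    rewards: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.99,
) -> float:
    """Mean squared TD error averaged over every timestep in the history window.

    q_pred[t][a]        online-network Q-value of action a for sub-history h_{0:t}
    q_next_target[t][a] target-network Q-value of action a for sub-history h_{0:t+1}
    (assumed one-step DQN-style target; not necessarily the paper's exact choice)
    """
    n = len(q_pred)
    if not (len(q_next_target) == len(actions) == len(rewards) == len(dones) == n):
        raise ValueError("all per-timestep inputs must have the same length")
    total = 0.0
    for t in range(n):
        # No bootstrapping past a terminal transition.
        bootstrap = 0.0 if dones[t] else max(q_next_target[t])
        target = rewards[t] + gamma * bootstrap
        total += (q_pred[t][actions[t]] - target) ** 2
    return total / n


# Two-step history: both rows contribute to the loss, not just the last one.
loss = intermediate_q_loss(
    q_pred=[[1.0, 0.5], [0.2, 0.8]],
    q_next_target=[[0.5, 1.0], [0.0, 0.0]],
    actions=[0, 1],
    rewards=[0.0, 1.0],
    dones=[False, True],
    gamma=0.9,
)
print(round(loss, 6))  # 0.025
```

Supervising only the final row here would give a loss of 0.04 from a single TD error; averaging over all rows is what gives the denser training signal the TL;DR describes.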

Abstract

Real-world reinforcement learning tasks often involve some form of partial observability where the observations only give a partial or noisy view of the true state of the world. Such tasks typically require some form of memory, where the agent has access to multiple past observations, in order to perform well. One popular way to incorporate memory is by using a recurrent neural network to access the agent's history. However, recurrent neural networks in reinforcement learning are often fragile and difficult to train, are susceptible to catastrophic forgetting, and sometimes fail completely as a result. In this work, we propose Deep Transformer Q-Networks (DTQN), a novel architecture utilizing transformers and self-attention to encode an agent's history. DTQN is designed modularly, and we compare results against several modifications to our base model. Our experiments demonstrate the transformer can solve partially observable tasks faster and more stably than previous recurrent approaches.
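The core mechanism the abstract refers to, self-attention over the agent's history, can be sketched as causally masked scaled dot-product attention: the representation at timestep t may only attend to observations 0..t. This is a simplified illustration under stated assumptions: a single head, no learned query/key/value projections, and a fixed `pos` table standing in for DTQN's learned positional encodings. The helper names (`causal_attention_weights`, `attend`, `encode_history`) are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def causal_attention_weights(queries: Matrix, keys: Matrix) -> Matrix:
    """Scaled dot-product attention weights with a causal mask.

    Row t is a softmax over positions 0..t only; later positions get weight 0,
    so timestep t never looks at future observations.
    """
    d = len(queries[0])
    weights: Matrix = []
    for t, q in enumerate(queries):
        # Scores for the visible prefix only: this is the causal mask.
        scores = [
            sum(qi * ki for qi, ki in zip(q, keys[j])) / math.sqrt(d)
            for j in range(t + 1)
        ]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights.append([e / z for e in exps] + [0.0] * (len(keys) - t - 1))
    return weights


def attend(weights: Matrix, values: Matrix) -> Matrix:
    """Weighted sum of value vectors for each row of attention weights."""
    dim = len(values[0])
    return [
        [sum(w * v[i] for w, v in zip(row, values)) for i in range(dim)]
        for row in weights
    ]


def encode_history(obs_embeddings: Matrix, pos_table: Matrix) -> Tuple[Matrix, Matrix]:
    """Add per-position vectors to observation embeddings, then self-attend causally.

    Assumption: queries, keys and values are the tokens themselves
    (no learned projections, single head).
    """
    tokens = [[o + p for o, p in zip(obs, pos)]
              for obs, pos in zip(obs_embeddings, pos_table)]
    weights = causal_attention_weights(tokens, tokens)
    return weights, attend(weights, tokens)


obs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
pos = [[0.1, 0.0], [0.0, 0.1], [0.1, 0.1]]  # stands in for learned positional encodings
weights, out = encode_history(obs, pos)
print(weights[0])  # [1.0, 0.0, 0.0]: the first observation can only attend to itself
```

Because each row is computed from a prefix of the history, one pass yields an encoding for every sub-history at once, which is what makes per-timestep Q-values cheap to obtain.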

Paper Structure (32 sections, 4 equations, 12 figures, 3 tables, 1 algorithm)

Figures (12)

  • Figure 1: Architectural diagram of DTQN. Each observation in the history is embedded independently, and Q-values are generated for each observation sub-history. Only the last set of Q-values is used to select the next action, but the other Q-values can be used for training.
  • Figure 2: Results showing the success rate of DTQN against baselines. DTQN is shown in blue, a simple attention network (ATTN) in brown, Deep Recurrent Q-Network (DRQN) (Hausknecht and Stone, 2015) in orange, and Deep Q-Network (DQN) (Mnih et al., 2015) in purple. Lines show the mean and shaded regions represent standard error across 5 random seeds. DTQN excels in both learning speed and final performance, clearly outperforming the baselines on nearly all domains. Refer to the baseline results section for discussion of results.
  • Figure 3: Gym-Gridverse Memory domains. The top row depicts the state while the bottom row shows the agent's current observation. The colored beacon informs the agent which flag to reach.
  • Figure 4: Attention bars for Gym-Gridverse Memory 7x7. Bars go from left to right, and observations go from top to bottom (i.e., the second observation attends to the first and second observations). Attention weights below 0.2 have been removed for visibility.
  • Figure 5: Baseline comparison of DTQN (blue) with DRQN (orange), DARQN (yellow), ADRQN (maroon), and DQN (purple), measured by evaluation success rate during training. Lines show mean success rate and shaded regions represent standard error across five random seeds. In all three cases, DTQN achieves the highest final success rate among all five algorithms.
  • ...and 7 more figures
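Two details from the captions of Figures 1 and 4 can be sketched directly: acting uses only the Q-values of the full history (the last row), and small attention weights are hidden when plotting. This is a minimal illustration; greedy selection is shown without the exploration a training loop would add, and the names `select_action` and `prune_attention` are hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def select_action(q_per_timestep: Sequence[Sequence[float]]) -> int:
    """Greedy action from the last row, i.e. the Q-values of the full history.

    Earlier rows (one per shorter sub-history) are not used for acting;
    per Figure 1 they can still contribute to the training loss.
    """
    last = q_per_timestep[-1]
    return max(range(len(last)), key=lambda a: last[a])


def prune_attention(weights: Matrix, threshold: float = 0.2) -> Matrix:
    """Zero out weights below `threshold`, mirroring how Figure 4 hides small weights."""
    return [[w if w >= threshold else 0.0 for w in row] for row in weights]


q_values = [[0.1, 0.9], [0.7, 0.3]]  # row t: Q-values for sub-history h_{0:t}
print(select_action(q_values))  # 0, chosen from the last row only
print(prune_attention([[1.0, 0.0], [0.15, 0.85]]))  # [[1.0, 0.0], [0.0, 0.85]]
```

Note that the first row would prefer action 1; the choice of action 0 shows that only the full-history row drives behavior.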