Sparse Masked Attention Policies for Reliable Generalization

Caroline Horsch, Laurens Engwegen, Max Weltevrede, Matthijs T. J. Spaan, Wendelin Böhmer

TL;DR

This paper uses a learned masking function which operates on, and is integrated with, the attention weights within an attention-based policy network to improve policy generalization to unseen tasks in the Procgen benchmark compared to standard PPO and masking approaches.

Abstract

In reinforcement learning, abstraction methods that remove unnecessary information from the observation are commonly used to learn policies that generalize better to unseen tasks. However, these methods often overlook a crucial weakness: the function that extracts the reduced-information representation has unknown generalization ability on unseen observations. In this paper, we address this problem by presenting an information removal method that generalizes more reliably to new states. We accomplish this by using a learned masking function which operates on, and is integrated with, the attention weights within an attention-based policy network. We demonstrate that our method significantly improves policy generalization to unseen tasks in the Procgen benchmark compared to standard PPO and masking approaches.
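The abstract describes the mask only as a learned function that operates on the attention weights. The following is a minimal single-head sketch of one plausible reading, written with NumPy: standard scaled dot-product attention whose weights are multiplied by a per-token gate in (0, 1) and then renormalized. The gate parameterization (a sigmoid of a linear score with a hypothetical weight vector `w_m`), the choice of one gate per key token, and the renormalization step are our assumptions, not details taken from the paper.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def masked_attention(tokens, w_q, w_k, w_v, w_m, eps=1e-8):
    """Single-head self-attention with attention weights gated by a learned mask.

    tokens:        (n, d) input tokens
    w_q, w_k, w_v: (d, d_k) query/key/value projections
    w_m:           (d,) HYPOTHETICAL mask-scoring vector (an assumption of
                   this sketch): it yields one gate value per key token.
    Returns the attended values (n, d_k), the gated weights (n, n), and the mask (n,).
    """
    q = tokens @ w_q
    k = tokens @ w_k
    v = tokens @ w_v
    # Ordinary scaled dot-product attention weights; each row sums to 1.
    attn = softmax(q @ k.T / np.sqrt(k.shape[-1]), axis=-1)
    # Assumed gate: one value in (0, 1) per token, computed from the token itself.
    mask = sigmoid(tokens @ w_m)
    # Scale every query's attention to token j by mask[j], suppressing masked tokens.
    gated = attn * mask[None, :]
    # Renormalize so each row is again a distribution over the remaining tokens.
    gated = gated / (gated.sum(axis=-1, keepdims=True) + eps)
    return gated @ v, gated, mask
```

With `w_m` set to zeros, every gate equals 0.5 and the result reduces to ordinary attention; driving one token's score strongly negative removes that token from every row of the weight matrix, which is the kind of information removal the abstract describes.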

Paper Structure (27 sections, 9 equations, 8 figures, 2 tables)

Figures (8)

  • Figure 1: A demonstration of the importance of removing unnecessary information from observations for policy generalization. In this task, the agent needs to collect the coin at the far right. The two states look very different but share the same immediate policy (move right). The decision can be reduced from considering the entire observation to considering only the areas highlighted in green (around the agent, at the coin location, and at the position to the right of the agent). After this reduction, the two states are almost identical.
  • Figure 2: The CNN feature extractor and positional encoding setup for processing inputs. These features form the input tokens to the attention block.
  • Figure 3: The normalized returns in unseen tasks for each of the Procgen games. The mean and standard error over 10 seeds is plotted.
  • Figure 4: The test returns (mean and standard error over 10 seeds) on unseen tasks from the dodgeball and maze environments. As expected, sparse masking yields a large improvement in generalization for dodgeball, whose underlying policy depends on only a sparse part of the observation, but a much smaller improvement for maze, whose underlying policy structure is relatively dense.
  • Figure 5: The normalized test return (mean and standard error over 10 seeds) of our method for different sparsity levels $\alpha$ in the dodgeball and maze environments. For more environments, see the corresponding sparsity-level figure in the appendix.
  • ...and 3 more figures
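The Figure 2 caption states that CNN features, together with a positional encoding, form the input tokens to the attention block. A minimal NumPy sketch of that tokenization step follows, assuming the CNN has already produced a `(C, H, W)` feature map: each spatial cell becomes one token. The paper's actual positional encoding is not described on this page; appending normalized (row, col) coordinates, as done here, is an illustrative assumption, and `feature_map_to_tokens` is a hypothetical helper name.

```python
import numpy as np


def feature_map_to_tokens(feature_map):
    """Turn a CNN feature map into attention tokens (hypothetical helper).

    feature_map: (C, H, W) array, assumed to be the CNN extractor's output.
    Returns an (H * W, C + 2) array: one token per spatial cell, with the
    cell's normalized (row, col) coordinates appended as an ASSUMED
    positional encoding.
    """
    c, h, w = feature_map.shape
    # (C, H, W) -> (H, W, C) -> (H*W, C); token index is row * W + col.
    feats = feature_map.transpose(1, 2, 0).reshape(h * w, c)
    # Coordinate grids in [-1, 1] with the same row-major ordering as feats.
    rows, cols = np.meshgrid(
        np.linspace(-1.0, 1.0, h), np.linspace(-1.0, 1.0, w), indexing="ij"
    )
    pos = np.stack([rows.ravel(), cols.ravel()], axis=-1)  # (H*W, 2)
    return np.concatenate([feats, pos], axis=-1)
```

Concatenating coordinates leaves the feature channels untouched and lets attention distinguish otherwise identical cells (for example, two empty floor tiles) by location; adding a learned or sinusoidal encoding to the features would be an equally reasonable choice.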