
SiT: Symmetry-Invariant Transformers for Generalisation in Reinforcement Learning

Matthias Weissenbacher, Rishabh Agarwal, Yoshinobu Kawahara

TL;DR

We introduce the Symmetry-Invariant Transformer (SiT), a scalable vision transformer that leverages both local and global data patterns in a self-supervised manner to improve generalisation.

Abstract

An open challenge in reinforcement learning (RL) is deploying a trained policy effectively in new or slightly different situations, as well as in semantically similar environments. We introduce the Symmetry-Invariant Transformer (SiT), a scalable vision transformer (ViT) that leverages both local and global data patterns in a self-supervised manner to improve generalisation. Central to our approach is Graph Symmetric Attention (GSA), which refines the standard self-attention mechanism to preserve graph symmetries, yielding invariant and equivariant latent representations. We show that SiT generalises better than ViTs on the MiniGrid and Procgen RL benchmarks, and we demonstrate its sample efficiency on Atari 100k and CIFAR10.
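To make the idea of attention constrained by a graph more concrete, here is a minimal NumPy sketch of single-head attention whose logits are reweighted by a graph matrix. It rests on an assumption that is not taken from the paper: the graph matrix g (shape P x P, one row and column per patch) multiplies the query-key logits elementwise before the softmax. The names graph_symmetric_attention and softmax are hypothetical, and the paper's actual GSA equation may differ.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row maximum for numerical stability, then normalise
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def graph_symmetric_attention(x, wq, wk, wv, g):
    """Single-head attention with P x P logits reweighted by a graph matrix g.

    ASSUMPTION (not the paper's verified equation): g multiplies the
    scaled query-key logits elementwise. x: (P, d) patch tokens, g: (P, P).
    """
    q, k, v = x @ wq, x @ wk, x @ wv
    logits = g * (q @ k.T) / np.sqrt(k.shape[-1])
    weights = softmax(logits, axis=-1)  # each row is a distribution over patches
    return weights @ v, weights

rng = np.random.default_rng(0)
P, d = 9, 8                              # 3x3 patch grid, 8-dimensional tokens
x = rng.normal(size=(P, d))
wq, wk, wv = (rng.normal(size=(d, d)) for _ in range(3))
g = np.ones((P, P))                      # an all-ones g recovers plain self-attention
out, w = graph_symmetric_attention(x, wq, wk, wv, g)
print(out.shape, np.allclose(w.sum(-1), 1.0))  # (9, 8) True
```

With an all-ones g the layer reduces to plain dot-product attention without positional encodings, which is equivariant to every permutation of the patches. A structured g guarantees equivariance only for the permutations that leave g unchanged, which is how a choice of graph can single out geometric symmetries such as rotations and flips.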

Paper Structure

This paper contains 36 sections, 2 theorems, 24 equations, 16 figures and 7 tables.

Key Result

Proposition 3.1

The Graph Symmetric Attention (GSA) mechanism (equation eq:defatt3) is a symmetry-preserving module. It can be invariant, equivariant, or both, with respect to symmetries of the input; which symmetry is preserved is dictated by the choice of graph matrix. To achieve rotation invariance, the subsequent application of equation …
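The proposition's claim can be probed numerically on a 3 x 3 patch grid (P = 9). The sketch below is not the paper's code: it assumes, hypothetically, that GSA multiplies the attention logits elementwise by G, and it picks a hypothetical G whose entry for patches i and j depends only on their squared distance. Only patch positions are permuted (the pixels inside each patch are left untouched), so it checks the global symmetry only. A rotation or flip of the grid should then move the output tokens in the same way (equivariance), and mean pooling over tokens should leave the result unchanged (invariance).

```python
import numpy as np

N, D = 3, 4  # 3x3 patch grid (P = 9), token width 4

def grid_perm(transform):
    # perm[new] = old: index of the patch that lands at each position after `transform`
    return transform(np.arange(N * N).reshape(N, N)).ravel()

def distance_tied_graph():
    # Hypothetical G: entry (i, j) depends only on the squared distance between
    # patches i and j, so it is unchanged by 90-degree rotations and flips.
    pos = np.array([(r, c) for r in range(N) for c in range(N)], dtype=float)
    d2 = ((pos[:, None, :] - pos[None, :, :]) ** 2).sum(-1)
    return 1.0 / (1.0 + d2)

def gsa(x, wq, wk, wv, g):
    # ASSUMED form: graph matrix g reweights the attention logits elementwise
    q, k, v = x @ wq, x @ wk, x @ wv
    logits = g * (q @ k.T) / np.sqrt(D)
    w = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
x = rng.normal(size=(N * N, D))
wq, wk, wv = (rng.normal(size=(D, D)) for _ in range(3))
g = distance_tied_graph()
out = gsa(x, wq, wk, wv, g)

for name, t in [("rot90", np.rot90), ("fliplr", np.fliplr)]:
    perm = grid_perm(t)
    out_t = gsa(x[perm], wq, wk, wv, g)                  # transform the input patches
    equivariant = np.allclose(out_t, out[perm])          # output tokens move the same way
    invariant = np.allclose(out_t.mean(0), out.mean(0))  # mean-pooled summary is unchanged
    print(name, equivariant, invariant)

swap = np.array([4, 1, 2, 3, 0, 5, 6, 7, 8])  # corner <-> centre: not a grid symmetry
print("swap leaves G unchanged:", np.allclose(g[np.ix_(swap, swap)], g))  # False
```

Both grid transformations should print True True. The last line shows the other side of the proposition: swapping a corner patch with the centre patch is a valid token permutation but not a symmetry of this G, so the equivariance argument no longer applies to it.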

Figures (16)

  • Figure 1: Local (patch-wise) and global transformations of observations of the CaveFlyer environment from the Procgen suite (cobbe20a). Permutation-invariant agents (tang2021the) cannot discern key features (a), in contrast to agents with local and/or global flip and rotation invariance (b) and (c).
  • Figure 2: Composition choices of the graph matrix $G \in \mathbb{R}^{P \times P}$ for $P=9$ that preserve different symmetries. Identical colours in $G$ denote shared weights. In (c), flips change the orientation of directed triangles (i.e., clockwise to anti-clockwise), whereas $90^\circ$ rotations preserve it.
  • Figure 3: SiT model architecture with local and global GSA modules.
  • Figure 4: Comparison of SiTs with CNNs and ViTs in terms of training and generalization performance on the LavaCrossing environments. SiTs substantially outperform both CNNs and ViTs.
  • Figure 5: Train vs. test observations of the MiniGrid LavaCrossing (easy-N1) environment. We test the agents' generalisation to varying goal and starting positions.
  • ...and 11 more figures
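The weight sharing described for Figure 2 can be illustrated with two hypothetical graph matrices on the same 3 x 3 grid; neither is claimed to be the paper's exact construction. An undirected matrix whose entries are tied by squared patch distance is unchanged by both 90° rotations and flips. A directed matrix built from the 2D cross product of the patch offsets from the centre (twice the signed area of the triangle formed by the centre and the two patches) mirrors the orientation argument of panel (c): rotations preserve it, while flips negate it.

```python
import numpy as np

N = 3  # 3x3 patch grid, i.e. P = N * N = 9 as in Figure 2
POS = np.array([(r, c) for r in range(N) for c in range(N)])  # row-major patch coordinates

def grid_perm(transform):
    # perm[new] = old: index of the patch that lands at each position after `transform`
    return transform(np.arange(N * N).reshape(N, N)).ravel()

def permute_graph(g, perm):
    # Computes P G P^T for the permutation matrix P with (P x)[k] = x[perm[k]]
    return g[np.ix_(perm, perm)]

def distance_classes():
    # Hypothetical undirected tying: patches i and j share a weight iff their
    # squared distances match; returns that class label for every pair.
    diff = POS[:, None, :] - POS[None, :, :]
    return (diff ** 2).sum(-1)

def orientation_graph():
    # Hypothetical directed G: 2D cross product of the offsets of patches i and j
    # from the centre patch (twice the signed triangle area); G[j, i] == -G[i, j].
    u = POS - POS[N * N // 2]
    return u[:, None, 0] * u[None, :, 1] - u[:, None, 1] * u[None, :, 0]

rot, flip = grid_perm(np.rot90), grid_perm(np.fliplr)
d2, g_dir = distance_classes(), orientation_graph()

print(len(np.unique(d2)))  # number of shared weights in the undirected G: 6
print(np.array_equal(permute_graph(d2, rot), d2),
      np.array_equal(permute_graph(d2, flip), d2))           # True True
print(np.array_equal(permute_graph(g_dir, rot), g_dir),      # rotation keeps orientation
      np.array_equal(permute_graph(g_dir, flip), -g_dir))    # flip reverses it: True True
```

The cross-product matrix is antisymmetric, so it encodes direction; its negation under a flip is exactly the clockwise-to-anti-clockwise reversal described in the caption, while the distance-tied matrix cannot tell a rotation from a flip.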

Theorems & Definitions (3)

  • Proposition 3.1: Symmetry Guarantee
  • Proposition 5.1: Symmetry Guarantee
  • Definition 1: Graph Matrix