
One-shot World Models Using a Transformer Trained on a Synthetic Prior

Fabio Ferreira, Moreno Schlageter, Raghu Rajan, Andre Biedenkapp, Frank Hutter

TL;DR

We propose One-Shot World Model (OSWM), a transformer world model learned in an in-context learning fashion from purely synthetic data sampled from a prior distribution. OSWM quickly adapts to the dynamics of a simple grid world and a custom control environment.

Abstract

A world model is a compressed spatial and temporal representation of a real-world environment that allows one to train an agent or execute planning methods. However, world models are typically trained on observations from the real-world environment, and they usually do not enable learning policies for other real environments. We propose One-Shot World Model (OSWM), a transformer world model that is learned in an in-context learning fashion from purely synthetic data sampled from a prior distribution. Our prior is composed of multiple randomly initialized neural networks, each of which models the dynamics of one state or reward dimension of a desired target environment. We adopt the supervised learning procedure of Prior-Fitted Networks by masking next-state and reward at random context positions and querying OSWM to make probabilistic predictions based on the remaining transition context. At inference time, OSWM quickly adapts to the dynamics of a simple grid world, the CartPole gym environment, and a custom control environment when given 1k transition steps as context, and it can then be used to successfully train agent policies that solve these environments. However, transferring to more complex environments currently remains a challenge. Despite these limitations, we see this work as an important stepping stone in the pursuit of learning world models purely from synthetic data.
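A minimal sketch of the kind of synthetic prior described above, assuming small untrained tanh MLPs (one per state dimension plus one for the reward) that map the concatenated state and action to a scalar, driven by uniformly random continuous actions. The network size, activation, initialization scale, action sampling and all helper names (`random_mlp`, `sample_synthetic_env`, `rollout`) are hypothetical illustrations, not the paper's settings.

```python
import math
import random
from typing import Callable, List, Tuple

# (state, action, next_state, reward); layout is an assumption for illustration
Transition = Tuple[List[float], List[float], List[float], float]


def random_mlp(in_dim: int, hidden: int, rng: random.Random) -> Callable[[List[float]], float]:
    """Hypothetical helper: an untrained one-hidden-layer tanh MLP with one scalar output."""
    w1 = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(hidden)]
    b1 = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
    w2 = [rng.gauss(0.0, 1.0 / math.sqrt(hidden)) for _ in range(hidden)]

    def forward(x: List[float]) -> float:
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(w1, b1)]
        return sum(w * hi for w, hi in zip(w2, h))

    return forward


def sample_synthetic_env(state_dim: int, action_dim: int, rng: random.Random, hidden: int = 16):
    """Draw one synthetic environment: one random network per state dimension plus one for the reward."""
    in_dim = state_dim + action_dim
    state_nets = [random_mlp(in_dim, hidden, rng) for _ in range(state_dim)]
    reward_net = random_mlp(in_dim, hidden, rng)

    def step(state: List[float], action: List[float]) -> Tuple[List[float], float]:
        x = state + action  # concatenate state and action features
        return [net(x) for net in state_nets], reward_net(x)

    return step


def rollout(state_dim: int, action_dim: int, length: int, seed: int = 0) -> List[Transition]:
    """Generate one synthetic sequence of (s, a, s', r) transitions using random actions."""
    rng = random.Random(seed)
    step = sample_synthetic_env(state_dim, action_dim, rng)
    state = [rng.uniform(-1.0, 1.0) for _ in range(state_dim)]
    transitions: List[Transition] = []
    for _ in range(length):
        action = [rng.uniform(-1.0, 1.0) for _ in range(action_dim)]
        next_state, reward = step(state, action)
        transitions.append((state, action, next_state, reward))
        state = next_state
    return transitions
```

For example, `rollout(4, 1, 1000)` yields a 1k-step sequence with 4-dimensional states (the size of CartPole's observation); each new seed draws a fresh environment from this illustrative prior.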

Paper Structure (27 sections, 1 equation, 7 figures, 7 tables, 1 algorithm)

Figures (7)

  • Figure 1: OSWM is trained on synthetic data sampled from a prior distribution of randomly initialized, untrained neural networks that mimic RL environments (left). Given a sequence of synthetic interactions, OSWM is optimized by predicting future dynamics at random cut-offs (center). RL agents can then be trained on OSWM to solve simple real environments given a context.
  • Figure 2: Evaluation scores for RL agent training on the OSWM for GridWorld, CartPole-v0, and SimpleEnv. Blue shows the mean over 3 runs, with the standard deviation in light blue. Orange and green depict the best and worst-performing agents, respectively.
  • Figure 3: Typical distribution patterns generated by the NN prior: (a) highly peaked, (b) wide or smoother, and (c) multi-modal distributions.
  • Figure 4: Reward distributions for the real and OSWM GridWorld and CartPole environments.
  • Figure 5: Typical distribution patterns generated by the Momentum prior: (a) broad, (b) multi-modal, and (c) sparse distributions.
  • ...and 2 more figures
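Figure 1 (center) and the abstract describe PFN-style training: next state and reward are hidden beyond a random cut-off, and the model must make probabilistic predictions for them from the remaining transition context. The sketch below covers only the data split and a per-target likelihood loss. It assumes an (s, a, s', r) tuple layout, and the function names are hypothetical. Prior-Fitted Networks are trained by minimizing the negative log-likelihood of held-out targets; the Gaussian likelihood used here is a simplification, not necessarily the output distribution OSWM uses.

```python
import math
import random
from typing import List, Tuple

# (state, action, next_state, reward); layout is an assumption for illustration
Transition = Tuple[List[float], List[float], List[float], float]


def split_at_random_cutoff(seq: List[Transition], rng: random.Random, min_context: int = 1):
    """Split a synthetic sequence at a random cut-off.

    Transitions before the cut-off are fully visible context; for the rest,
    only (s, a) is shown and (s', r) become the prediction targets.
    """
    if len(seq) <= min_context:
        raise ValueError("sequence too short for a context/query split")
    cut = rng.randint(min_context, len(seq) - 1)  # inclusive bounds: at least one query remains
    context = seq[:cut]
    queries = [(s, a) for s, a, _, _ in seq[cut:]]
    targets = [(s_next, r) for _, _, s_next, r in seq[cut:]]
    return context, queries, targets


def gaussian_nll(y: float, mean: float, std: float) -> float:
    """Negative log-likelihood of y under N(mean, std^2), a simplified predictive distribution."""
    var = std * std
    return 0.5 * math.log(2.0 * math.pi * var) + (y - mean) ** 2 / (2.0 * var)
```

In a full training loop, the transformer would read `context` and `queries`, output a predictive distribution for each target dimension, and the loss would average the negative log-likelihood over all masked targets.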