Improving World Models using Deep Supervision with Linear Probes

Andrii Zahorodnii

TL;DR

The paper addresses robust world-representation learning in end-to-end predictive networks by introducing a linear-probe deep supervision term that encourages true world features to be linearly decodable from the hidden state. Using a lidar-based Flappy Bird setup, the authors stack an autoencoder and an MDN-LSTM that predicts the next latent vector and an episode-end flag, while a linear probe attempts to recover underlying world variables from the LSTM's hidden state. They find that increasing the probe weight improves both training and test predictive performance, enhances the decodability of world features (including features not used in the probe loss), reduces distribution drift in some phases of the game, and improves training stability. Adding the probe has roughly the same effect as doubling the model size. The approach offers practical advantages for compute-constrained settings and robotics by enabling more robust, data-efficient latent representations early in training, potentially reducing deployment costs and enabling on-device reasoning.

Abstract

Developing effective world models is crucial for creating artificial agents that can reason about and navigate complex environments. In this paper, we investigate a deep supervision technique for encouraging the development of a world model in a network trained end-to-end to predict the next observation. While deep supervision has been widely applied for task-specific learning, our focus is on improving world models. Using an experimental environment based on the Flappy Bird game, where the agent receives only LIDAR measurements as observations, we explore the effect of adding a linear probe component to the network's loss function. This additional term encourages the network to encode a subset of the true underlying world features into its hidden state. Our experiments demonstrate that this supervision technique improves both training and test performance, enhances training stability, and results in more easily decodable world features, even for world features that were not included in the training. Furthermore, we observe reduced distribution drift in networks trained with the linear probe, particularly during high-variability phases of the game (flying between successive pipe encounters). Adding the world-feature loss component had roughly the same effect as doubling the model size, suggesting that the linear probe technique is particularly beneficial in compute-limited settings or when aiming to achieve the best performance with smaller models. These findings contribute to our understanding of how to develop more robust and sophisticated world models in artificial agents, paving the way for further advancements in this field.
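The combined objective described above can be sketched as a predictive loss plus λ times a linear-probe loss. This is a minimal sketch under assumptions that this summary does not confirm: each latent dimension is modeled by an independent one-dimensional Gaussian mixture, the episode-end flag uses binary cross-entropy, and the probe is a linear map W h + b scored with mean squared error against the true world features. All function names are hypothetical, and the paper's own four equations may differ.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical per-dimension mixture parameters: (logits, means, log std devs).
MixtureParams = Tuple[Sequence[float], Sequence[float], Sequence[float]]


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Normalize mixture logits into log-weights (numerically stable)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def gmm_nll(params: MixtureParams, target: float) -> float:
    """Negative log-likelihood of a scalar target under a 1-D Gaussian mixture."""
    logits, mus, log_sigmas = params
    terms = [
        lp - ls - 0.5 * math.log(2 * math.pi) - 0.5 * ((target - mu) / math.exp(ls)) ** 2
        for lp, mu, ls in zip(log_softmax(logits), mus, log_sigmas)
    ]
    m = max(terms)  # log-sum-exp over mixture components
    return -(m + math.log(sum(math.exp(t - m) for t in terms)))


def bce_with_logit(logit: float, label: float) -> float:
    """Binary cross-entropy for the episode-end flag, computed from a logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def probe_mse(hidden, weights, bias, features) -> float:
    """Mean squared error of the linear readout W h + b against true world features."""
    preds = [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(weights, bias)]
    return sum((p - f) ** 2 for p, f in zip(preds, features)) / len(features)


def total_loss(mdn_params, next_latent, end_logit, end_label,
               hidden, probe_w, probe_b, world_features, lam: float) -> float:
    """Next-latent and episode-end prediction loss plus lam times the probe loss."""
    pred = sum(gmm_nll(p, z) for p, z in zip(mdn_params, next_latent))
    pred += bce_with_logit(end_logit, end_label)
    return pred + lam * probe_mse(hidden, probe_w, probe_b, world_features)
```

Setting `lam=0.0` recovers the plain predictive loss (the λ = 0 baseline in the figures). In an actual training loop these terms would be computed by an autodiff framework over batches and timesteps, with the probe weights presumably learned jointly with the LSTM.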

Paper Structure

This paper contains 12 sections, 4 equations, and 8 figures.

Figures (8)

  • Figure 1: Flappy Bird environment with lidar. (A) The environment. (B) The agent only observes the lidar signal as a function of time. (C) The only available actions are no-op and flap. (D-F) The environment provides true variables of the world, such as the player's rotation angle, vertical velocity, and position.
  • Figure 2: Network architecture and training setup. (A) The vision autoencoder compresses the 180-dimensional raw observations into an 8-dimensional latent space. (B) The MDN-LSTM world model takes the current latent observation vector and action as inputs, and predicts the distribution of the next latent vector and an episode-end flag. Optionally, the LSTM's hidden state is encouraged to capture true world variables (such as player rotation angle, y position, y velocity, etc.) through a linear probe.
  • Figure 3: The effect of increasing the linear probe weight $\lambda$ on the original (next latent state prediction) loss. Both training (A, B) and test (C) predictive losses decrease as $\lambda$ increases. For every choice of $\lambda$, 20 RNNs were initialized from different random seeds. Error bars in panels (B) and (C) indicate the s.d.
  • Figure 4: Decodability of world features from the network's hidden state for $\lambda=0$ and $\lambda=64$. Every panel represents one world feature. The last three panels (player's vertical position, velocity, and rotation; green) represent features which were explicitly included in the loss function for $\lambda=64$. Note that even the features which were not explicitly part of the loss function exhibit decodability above that of an untrained randomly-initialized network, but only for $\lambda=64$. For every choice of $\lambda$, four RNNs initialized from different random seeds are shown.
  • Figure 5: Distribution drift comparison. Networks trained with the linear probe ($\lambda = 64$) exhibit reduced distribution drift compared to those without the probe ($\lambda = 0$) for both the training policy and a random policy. However, note that this result does not hold for the timesteps near t=35, which correspond to the bird going through the first pair of pipes. The error bars represent s.d. across 10 random seeds.
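The decodability comparison in Figure 4 can be illustrated with a least-squares linear readout. This minimal sketch assumes decodability is scored as the coefficient of determination (R²) of a linear regression from hidden states to one world feature; the summary does not state the exact metric, fitting procedure, or train/test split, and the helper names are hypothetical.

```python
from typing import List, Sequence


def solve(a: List[List[float]], b: List[float]) -> List[float]:
    """Solve the square system a x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    m = [list(row) + [bi] for row, bi in zip(a, b)]  # augmented matrix
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(n):
            if r != col:
                f = m[r][col] / m[col][col]
                for c in range(col, n + 1):
                    m[r][c] -= f * m[col][c]
    return [m[i][n] / m[i][i] for i in range(n)]


def fit_linear_probe(hiddens: Sequence[Sequence[float]], targets: Sequence[float],
                     ridge: float = 1e-6) -> List[float]:
    """Least-squares weights (last entry is the bias) mapping hidden states to one feature."""
    xs = [list(h) + [1.0] for h in hiddens]  # append a constant column for the bias
    d = len(xs[0])
    # Normal equations X^T X w = X^T y, with a tiny ridge on the non-bias weights.
    xtx = [[sum(x[i] * x[j] for x in xs) + (ridge if i == j and i < d - 1 else 0.0)
            for j in range(d)] for i in range(d)]
    xty = [sum(x[i] * y for x, y in zip(xs, targets)) for i in range(d)]
    return solve(xtx, xty)


def r_squared(weights: Sequence[float], hiddens: Sequence[Sequence[float]],
              targets: Sequence[float]) -> float:
    """Coefficient of determination of the probe's predictions (1.0 = perfect readout)."""
    preds = [sum(w * v for w, v in zip(weights, list(h) + [1.0])) for h in hiddens]
    mean_y = sum(targets) / len(targets)
    ss_res = sum((y - p) ** 2 for y, p in zip(targets, preds))
    ss_tot = sum((y - mean_y) ** 2 for y in targets)
    return 1.0 - ss_res / ss_tot
```

In a setup like Figure 4, one would fit the probe on hidden states from some rollouts, score R² on held-out rollouts, and repeat the same procedure on an untrained, randomly initialized network to obtain the baseline. The ridge term only guards against a singular XᵀX when hidden units are collinear.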