Learning Cognitive Maps from Transformer Representations for Efficient Planning in Partially Observed Environments

Antoine Dedieu, Wolfgang Lehrach, Guangyao Zhou, Dileep George, Miguel Lázaro-Gredilla

TL;DR

The paper tackles planning in partially observed environments by enabling explicit world-model extraction from transformer representations. It introduces the Transformer with Discrete Bottlenecks (TDB), which compresses history into latent codes that define an interpretable cognitive map, enabling an external planner to efficiently compute shortest paths. Empirical results across 2D/3D navigation and a text dataset show that a TDB keeps predictive performance close to that of a vanilla transformer while delivering exponential speedups in planning and emergent in-context capabilities. Using multiple bottlenecks accelerates training and improves robustness. Limitations include reliance on categorical inputs and imperfect disentanglement. Even so, the approach offers a scalable route to planning-compatible world models and latent graphs derived from transformer representations.
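In a vector-quantized bottleneck, each transformer output is replaced by its nearest codebook vector, and the index of that vector becomes the discrete latent code. The sketch below illustrates this compression step. It assumes a standard nearest-neighbor quantizer with squared Euclidean distance, and the codebook values are made up. Any training machinery is omitted, such as the codebook losses and straight-through gradient estimator used in VQ-VAE.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(h: Vector, codebook: List[Vector]) -> Tuple[int, Vector]:
    """Snap h to its nearest codebook vector (assumed nearest-neighbor VQ).

    Returns (index, vector): the index is the discrete latent code; the vector
    is what downstream layers would consume in place of h.
    """
    best = min(range(len(codebook)), key=lambda k: squared_distance(h, codebook[k]))
    return best, codebook[best]


# Illustrative 3-entry codebook in 2 dimensions (values are made up).
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
outputs = [[0.1, -0.2], [0.9, 0.2], [0.2, 0.8], [1.1, -0.1]]
codes = [quantize(h, codebook)[0] for h in outputs]
print(codes)  # [0, 1, 2, 1]
```

According to the Figure 2 caption, the quantized observation outputs are then concatenated with the next action embedding to predict the next observation.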

Abstract

Despite their stellar performance on a wide range of tasks, including in-context tasks only revealed during inference, vanilla transformers and variants trained for next-token prediction (a) do not learn an explicit world model of their environment which can be flexibly queried and (b) cannot be used for planning or navigation. In this paper, we consider partially observed environments (POEs), where an agent receives perceptually aliased observations as it navigates, which makes path planning hard. We introduce a transformer with (multiple) discrete bottleneck(s), TDB, whose latent codes learn a compressed representation of the history of observations and actions. After training a TDB to predict the future observation(s) given the history, we extract interpretable cognitive maps of the environment from the indices of its active bottleneck(s). These maps are then paired with an external solver to solve (constrained) path planning problems. First, we show that a TDB trained on POEs (a) retains the near-perfect predictive performance of a vanilla transformer or an LSTM while (b) solving shortest path problems exponentially faster. Second, a TDB extracts interpretable representations from text datasets, while reaching higher in-context accuracy than vanilla sequence models. Finally, in new POEs, a TDB (a) reaches near-perfect in-context accuracy, (b) learns accurate in-context cognitive maps, and (c) solves in-context path planning problems.
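The abstract describes a three-step pipeline: read out the active bottleneck index at each step, build a latent graph, and hand that graph to a solver. The sketch below illustrates this pipeline and is not the authors' code. It assumes a single bottleneck, one discrete code per time step, and plain breadth-first search as the external solver. The names `build_cognitive_map` and `shortest_path` are hypothetical. Because planning runs over latent codes rather than raw observations, two cells that emit the same aliased observation can still be distinct nodes.

```python
from collections import Counter, defaultdict, deque
from typing import Dict, List, Optional, Sequence, Tuple

# Latent graph: node -> action -> Counter of observed successor nodes.
Graph = Dict[int, Dict[str, Counter]]


def build_cognitive_map(codes: Sequence[int], actions: Sequence[str]) -> Graph:
    """Hypothetical helper: count (code, action) -> next-code transitions.

    codes[t] is the active bottleneck index at step t (assumed: one code per
    step); actions[t] is the action taken between codes[t] and codes[t + 1].
    """
    if len(codes) != len(actions) + 1:
        raise ValueError("expected exactly one more code than actions")
    graph: Graph = defaultdict(lambda: defaultdict(Counter))
    for src, act, dst in zip(codes, actions, codes[1:]):
        graph[src][act][dst] += 1
    return graph


def shortest_path(graph: Graph, start: int, goal: int) -> Optional[List[str]]:
    """Hypothetical helper: BFS as a stand-in for the external solver.

    Follows the most frequently observed successor for each (node, action)
    pair and returns the action sequence, or None if goal is unreachable.
    """
    parents: Dict[int, Tuple[int, str]] = {}
    visited = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            plan: List[str] = []
            while node != start:  # walk back through recorded parents
                node, act = parents[node]
                plan.append(act)
            return plan[::-1]
        for act, successors in graph.get(node, {}).items():
            nxt = successors.most_common(1)[0][0]
            if nxt not in visited:
                visited.add(nxt)
                parents[nxt] = (node, act)
                queue.append(nxt)
    return None


# Toy corridor: codes 0..3, walking right and then back left.
walk_codes = [0, 1, 2, 3, 2, 1, 0]
walk_actions = ["R", "R", "R", "L", "L", "L"]
cmap = build_cognitive_map(walk_codes, walk_actions)
print(shortest_path(cmap, 0, 3))  # ['R', 'R', 'R']
```

Counting successors, rather than storing a single edge, keeps the map usable when the learned codes are noisy. The constrained planning problems mentioned in the abstract would need a different solver, and this sketch does not attempt them.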

Paper Structure (57 sections, 18 equations, 13 figures, 9 tables)

Figures (13)

  • Figure 1: An agent is trained on random walks in an aliased room with no reward and unknown layout. At test time, given a novel random walk (in fuchsia), it has to find a shortest path (in red) between room positions A and B. While a vanilla transformer solves this path planning problem with forward rollouts, which can be exponentially expensive due to aliasing, our transformer variant pairs its learned cognitive map with an external solver.
  • Figure 2: Our proposed transformer with a (single) discrete bottleneck. The respective linear embeddings of observations and actions go through a causal transformer. The observation outputs are compressed by the vector quantizer, then concatenated with the next action embedding in order to predict the next observation. Finally, a cognitive map of the environment is built from the active bottleneck indices.
  • Figure 3: [Left] Top: $2$D aliased room of size $15\times20$ with $O=4$ unique observations and $2$ identical $4\times4$ patches (in fuchsia). Bottom: aliased cube of edge size $6$ with $O=12$ and non-Euclidean dynamics. [Center left] Top: cognitive map learned with a TDB($S=3, M=1$). For visualization, each latent node is labeled with the observation (resp. placed at the $2$D ground-truth spatial position) that has the highest empirical frequency when this node is active. Bottom: similar, but the nodes are positioned with the Kamada-Kawai layout algorithm. [Center] Isometric view of a simulated $3$D environment. The agent navigates with egocentric actions and collects RGB images. [Center right] Three cluster centers: the cluster indices serve as categorical observations. [Right] Cognitive map learned with a TDB($S=3, M=4$): each location is represented by four nodes in the latent graph, corresponding to the agent's four heading directions. Also see the appendix figure and appendix section on this $3$D environment.
  • Figure 4: [Left] For all context lengths $k$, TDB($S=1, M=4$) achieves higher in-context accuracy than an LSTM and a vanilla transformer on the GINC test dataset (Xie et al., 2021), while TDB($S=5, M=1$) is the best model for large contexts. [Right] The learned latent graph of TDB($S=5, M=1$) exhibits five clusters, each corresponding to a color-coded concept.
  • Figure 5: [Left] In novel $2$D aliased test rooms, TDB$(S=3, M=4)$ perfectly (a) predicts the next observation and (b) solves in-context path planning problems. These in-context capabilities emerge as the number of training rooms increases. A vanilla transformer only solves the prediction problem (a), which an LSTM struggles to do. [Right] Two latent graphs learned in context by the TDB on two new test rooms. By design, (a) a $3\times3$ fuchsia-coded patch is repeated twice and (b) the room partitions induced by the colors are the same.
  • ...and 8 more figures
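The Figure 3 caption describes how the maps are visualized. Each latent node is labeled with the observation that is most frequent while that node is active, or placed at the most frequent ground-truth 2D position. The sketch below shows that bookkeeping. It assumes the codes and labels are already aligned per time step, and `label_nodes` and the toy data are hypothetical.

```python
from collections import Counter, defaultdict
from typing import Dict, Hashable, Sequence


def label_nodes(codes: Sequence[int], labels: Sequence[Hashable]) -> Dict[int, Hashable]:
    """Hypothetical helper: map each latent node to its most frequent label.

    labels[t] is whatever was recorded at step t while codes[t] was active:
    an observation, or a ground-truth (row, col) position.
    Ties go to the label seen first, per Counter.most_common ordering.
    """
    if len(codes) != len(labels):
        raise ValueError("codes and labels must be aligned per time step")
    counts: Dict[int, Counter] = defaultdict(Counter)
    for code, label in zip(codes, labels):
        counts[code][label] += 1
    return {code: c.most_common(1)[0][0] for code, c in counts.items()}


codes = [0, 1, 1, 2, 0, 1]
observations = ["red", "blue", "blue", "red", "red", "green"]
positions = [(0, 0), (0, 1), (0, 1), (1, 1), (0, 0), (0, 1)]
print(label_nodes(codes, observations))  # {0: 'red', 1: 'blue', 2: 'red'}
print(label_nodes(codes, positions))     # {0: (0, 0), 1: (0, 1), 2: (1, 1)}
```

Nodes 0 and 2 both receive the observation label "red" but map to different positions. This is the aliasing case the latent graph is meant to resolve.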