Machine Theory of Mind

Neil C. Rabinowitz, Frank Perbet, H. Francis Song, Chiyuan Zhang, S. M. Ali Eslami, Matthew Botvinick

TL;DR

The paper presents ToMnet, a meta-learned observer that builds a machine Theory of Mind by learning a general prior over agent behavior and an agent-specific posterior from observed trajectories. It uses a three-part architecture (character net, mental state net, prediction net) to predict future actions, object consumptions, and successor representations across diverse agent species in gridworld POMDPs. Through experiments with random, algorithmic, and deep RL agents, the work demonstrates both implicit and explicit belief modeling, including false beliefs, and shows how disentangled embeddings can reveal underlying behavioral abstractions. These results point toward interpretable, sample-efficient multi-agent AI capable of rapid adaptation to new agents and scenarios, while highlighting limitations and avenues for scaling to richer environments.
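The data flow through the three sub-networks can be sketched as below. This is a hypothetical, heavily simplified illustration, not the authors' implementation: random linear maps stand in for the paper's learned encoders, and the class name `ToyToMnet`, the embedding sizes, and the five-action space are assumptions made only for the example. It shows how past episodes are pooled into $e_{\mathrm{char}}$, how the current episode yields $e_{\mathrm{mental}}$, and how both condition the prediction net's next-action distribution $\hat{\pi}$.

```python
import numpy as np

N_ACTIONS = 5  # assumption: up, down, left, right, stay


def softmax(z):
    z = z - np.max(z)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()


class ToyToMnet:
    """Hypothetical, heavily simplified ToMnet forward pass.

    Random linear maps stand in for learned encoders; only the data flow
    (character net + mental state net -> prediction net) mirrors the text.
    """

    def __init__(self, step_dim, state_dim, char_dim=2, mental_dim=2, seed=0):
        rng = np.random.default_rng(seed)
        self.W_char = rng.normal(size=(char_dim, step_dim))
        self.W_mental = rng.normal(size=(mental_dim, step_dim))
        self.W_pred = rng.normal(size=(N_ACTIONS, state_dim + char_dim + mental_dim))

    def character_net(self, past_trajectories):
        # One embedding per past episode (mean over its steps), summed over
        # episodes; with N_past = 0 the embedding stays at zero.
        e_char = np.zeros(self.W_char.shape[0])
        for traj in past_trajectories:  # traj: array of shape (T, step_dim)
            e_char += self.W_char @ np.asarray(traj).mean(axis=0)
        return e_char

    def mental_state_net(self, current_trajectory):
        # Embeds the steps observed so far in the current episode.
        if len(current_trajectory) == 0:
            return np.zeros(self.W_mental.shape[0])
        return self.W_mental @ np.asarray(current_trajectory).mean(axis=0)

    def prediction_net(self, query_state, e_char, e_mental):
        # Next-step action probabilities pi_hat(a | x, e_char, e_mental).
        features = np.concatenate([query_state, e_char, e_mental])
        return softmax(self.W_pred @ features)


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    net = ToyToMnet(step_dim=8, state_dim=6)
    past = [rng.normal(size=(10, 8)) for _ in range(3)]  # N_past = 3 episodes
    e_char = net.character_net(past)
    e_mental = net.mental_state_net(rng.normal(size=(4, 8)))
    pi_hat = net.prediction_net(rng.normal(size=6), e_char, e_mental)
    print(pi_hat, pi_hat.sum())
```

In the real model every map is trained end to end on action prediction across many agents, which is what turns the zero-observation output into a learned prior and the pooled embedding into an agent-specific posterior.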

Abstract

Theory of mind (ToM; Premack & Woodruff, 1978) broadly refers to humans' ability to represent the mental states of others, including their desires, beliefs, and intentions. We propose to train a machine to build such models too. We design a Theory of Mind neural network -- a ToMnet -- which uses meta-learning to build models of the agents it encounters, from observations of their behaviour alone. Through this process, it acquires a strong prior model for agents' behaviour, as well as the ability to bootstrap to richer predictions about agents' characteristics and mental states using only a small number of behavioural observations. We apply the ToMnet to agents behaving in simple gridworld environments, showing that it learns to model random, algorithmic, and deep reinforcement learning agents from varied populations, and that it passes classic ToM tasks such as the "Sally-Anne" test (Wimmer & Perner, 1983; Baron-Cohen et al., 1985) of recognising that others can hold false beliefs about the world. We argue that this system -- which autonomously learns how to model other agents in its world -- is an important step forward for developing multi-agent AI systems, for building intermediating technology for machine-human interaction, and for advancing the progress on interpretable AI.
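The "strong prior" and the few-shot update it enables can be made concrete for the random agents of Figure 3, whose caption compares the ToMnet with analytic Bayes-optimal posteriors for a species $\mathcal{S}(\alpha)$. The sketch below assumes, as an illustration rather than a restatement of the paper, that each agent's policy is drawn from a symmetric Dirichlet with concentration $\alpha$ over five actions. Under that assumption the Bayes-optimal next-action prediction after observing action counts $n_a$ (out of $N$) is $(n_a + \alpha)/(N + K\alpha)$, and the KL divergence is the mismatch measure used in Figure 3(c). Function names are hypothetical.

```python
import numpy as np


def dirichlet_posterior_predictive(counts, alpha):
    """P(next action = a | observed counts) when the agent's policy is
    drawn from a symmetric Dirichlet(alpha) over len(counts) actions."""
    counts = np.asarray(counts, dtype=float)
    k = counts.size
    return (counts + alpha) / (counts.sum() + k * alpha)


def kl(p, q):
    """KL(p || q) in nats; assumes q > 0 wherever p > 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


if __name__ == "__main__":
    counts = [1, 0, 0, 0, 0]  # action 0 observed once
    print(dirichlet_posterior_predictive(counts, 0.01)[0])  # ≈ 0.962
    print(dirichlet_posterior_predictive(counts, 3.0)[0])   # 0.25
```

With a small $\alpha$ (near-deterministic agents) a single observation almost pins down the policy, while with a large $\alpha$ (stochastic agents) the prediction barely moves from the uniform prior of 0.2, which is the contrast drawn in the two rows of Figure 2(c).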

Paper Structure

This paper contains 30 sections, 6 equations, 15 figures, 1 table.

Figures (15)

  • Figure 1: ToMnet architecture. The character net parses an agent's past trajectories from a set of POMDPs to form a character embedding, $e_{\mathrm{char}}$. The mental state net parses the agent's trajectory on the current episode, to form an embedding of its mental state, $e_{\mathrm{mental}}$. These embeddings are fed into the prediction net, which is then queried with a current state. This outputs predictions about future behaviour, such as next-step action probabilities ($\hat{\pi}$), probabilities of whether certain objects will be consumed ($\hat{c}$), and predicted successor representations (Dayan, 1993).
  • Figure 2: Example gridworld in which a random agent acts. (a) Example past episode. Coloured squares indicate objects. Red arrows indicate the positions and actions taken by the agent. (b) Example query: a state from a new MDP. Black dot indicates agent position. (c) Predictions for the next action taken by the agent shown in (a) in query state (b). Top: prediction from ToMnet trained on agents with near-deterministic policies. Bottom: prediction from ToMnet trained on agents with more stochastic policies.
  • Figure 3: ToMnet trained on random agents. (a) Likelihood of agent's true actions under the ToMnet's predictions, given that the ToMnet has been trained on species $\mathcal{S}(\alpha)$. Priors are shown in light blue, and posteriors after observing that agent perform just that same action in $N_{\mathrm{past}} = 1$ or $5$ past episodes in darker blue. Dots are data from the ToMnet; solid lines are from the analytic Bayes-optimal posteriors specialised to the respective $\mathcal{S}(\alpha)$. (b) Character embeddings $e_{\mathrm{char}} \in \mathbb{R}^2$ of different agents. Dots are coloured by which action was observed to occur most during $N_{\mathrm{past}}=10$ past episodes, and are darker the higher that count. (c) Average KL-divergence between agents' true and predicted policies when the ToMnet is trained on agents from one species, $\mathcal{S}(\alpha)$, but tested on agents from a different species $\mathcal{S}(\alpha^\prime)$. Dots show values from the ToMnet; lines show analytic expected KLs when using analytic Bayes-optimal inference as in (a). Values calculated for $N_{\mathrm{past}}=1$. The ToMnet thus learns an effective prior for the species it is trained on. (d) Same, but including a ToMnet trained on a mixture of species (with $N_{\mathrm{past}}=5$). The ToMnet here implicitly learns to perform hierarchical inference.
  • Figure 4: ToMnet on goal-driven agents. (a) Past trajectory of an example agent. Coloured squares indicate the four objects. Red arrows indicate the position and action taken by the agent. (b) Example query: a state from a new MDP. Black dot indicates agent position. (c) ToMnet's prediction for the agent's next action (top) and object consumed at the end of the episode (bottom) for the query MDP in (b), given the past observation in (a). (d) ToMnet's prediction of the successor representation (SR) for query (b), using discount $\gamma = 0.9$. Darker shading indicates higher expected discounted state occupancy.
  • Figure 5: ToMnet on goal-driven agents, continued. (a) This ToMnet sees only snapshots of single observation/action pairs (red arrow) from a variable number of past episodes (one shown here). (b) Increasing $N_{\mathrm{past}}$ leads to better predictions; here we show the average posterior probability assigned to the true action. Even when $N_{\mathrm{past}}=0$, the action probability is greater than chance, since all agents in the species have similar policies in some regions of the state space. (c) Predicted policy for different initial agent locations in a query MDP, for different numbers of past observations. Arrows show resultant vectors for the predicted policies, i.e. $\sum_k \mathbf{a}_k \cdot \hat{\pi}(\mathbf{a}_k | x, e_{\mathrm{char}})$. When $N_{\mathrm{past}} = 0$, the ToMnet has no information about the agent's preferred object, so the predicted policy exhibits no net object preference. When $N_{\mathrm{past}}>0$, the ToMnet infers a preference for the pink object. When the agent is stuck in the top right chamber, the ToMnet predicts that it will always consume the blue object, as this terminates the episode as soon as possible, avoiding a costly penalty. (d) 2D embedding space of the ToMnet, showing values of $e_{\mathrm{char}}$ from 100 different agents. Agents are colour-coded by their ground-truth preferred objects; saturation increases with $N_{\mathrm{past}}$, with the grey dots in the centre denoting agents with $N_{\mathrm{past}}=0$.
  • ...and 10 more figures
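Figure 4(d) shows the ToMnet's predicted successor representation (SR) with $\gamma = 0.9$, i.e. expected discounted state occupancy. For a fixed policy with state-to-state transition matrix $P$, the ground-truth SR that such a prediction targets is $M = \sum_{t \ge 0} (\gamma P)^t = (I - \gamma P)^{-1}$ (Dayan, 1993). The sketch below computes it directly; it is not how the ToMnet produces its prediction. The terminal-state convention (an all-zero row, so occupancy stops when an object is consumed) and the corridor example are assumptions chosen for illustration, and the function names are hypothetical.

```python
import numpy as np


def successor_representation(P, gamma=0.9):
    """M[s, s'] = E[sum_t gamma^t * 1(s_t = s') | s_0 = s] under a fixed policy.

    P is the state-to-state transition matrix under that policy. Terminal
    states have all-zero rows, so occupancy stops once the episode ends
    (an assumed convention for episode termination).
    """
    P = np.asarray(P, dtype=float)
    n = P.shape[0]
    # Closed form of the geometric series sum_t (gamma * P)^t.
    return np.linalg.solve(np.eye(n) - gamma * P, np.eye(n))


def chain_policy_matrix(n_states):
    """Hypothetical toy policy: the agent always steps right along a
    corridor, and the last cell (an object) terminates the episode."""
    P = np.zeros((n_states, n_states))
    for s in range(n_states - 1):
        P[s, s + 1] = 1.0
    return P


if __name__ == "__main__":
    M = successor_representation(chain_policy_matrix(4), gamma=0.9)
    print(M[0])  # occupancy from the first cell: 1, 0.9, 0.81, 0.729
```

Starting from the first cell, the agent spends one step in each cell before consuming the object, so occupancy decays as $\gamma^t$ along the path; this is the pattern the darker shading in Figure 4(d) encodes along the agent's expected route.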