Disentangling Representations through Multi-task Learning
Pantelis Vafidis, Aman Bhargava, Antonio Rangel
TL;DR
The paper addresses how agents learning multiple tasks from noisy evidence can develop disentangled, abstract world-model representations. It introduces a formal framework where the latent state $\mathbf Z(t)$ encodes a maximum-likelihood estimate $\mu(t)$ of the ground-truth $\mathbf x^*$, and proves that disentanglement emerges when the number of tasks $N_{\text{task}}$ exceeds the latent dimensionality $D$ under non-zero noise. The authors validate the theory with autoregressive models (RNNs/LSTMs) and GPT-2 transformers, showing continuous 2D attractors, orthogonal latent factors, and strong zero-shot OOD generalization, while demonstrating robustness across architectures and boundary geometries. The work has broad implications for representation learning and neuroscience by linking multitask competence to world-model formation and suggesting conditions under which neural and artificial systems converge to interpretable, topology-preserving representations. These insights offer a principled account of why complex models, including LLMs, can acquire human-aligned concepts and generalize to unseen scenarios through disentangled latent structure.
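To make the link between multi-task outputs and $\mu(t)$ concrete, here is a minimal toy sketch. It is not the authors' model or proof. It assumes $D = 2$ latent factors, i.i.d. Gaussian observation noise with known $\sigma$, a flat prior, and tasks defined by linear boundaries $\mathbf w_i \cdot \mathbf x > b_i$ with unit-norm $\mathbf w_i$; the function names (`simulate_mle`, `task_probabilities`, `decode_mle`) are hypothetical. Under these assumptions $\mu(t)$ is the running mean of the observations, and the Bayes-optimal output for task $i$ is $\Phi(\sqrt{t}\,(\mathbf w_i \cdot \mu(t) - b_i)/\sigma)$, where $\Phi$ is the standard normal CDF. Applying $\Phi^{-1}$ to the $N_{\text{task}}$ outputs turns them into linear measurements of $\mu(t)$, which least squares recovers whenever the boundary normals span all $D$ dimensions.

```python
import numpy as np
from statistics import NormalDist

STD_NORMAL = NormalDist()  # standard normal distribution (mean 0, std 1)


def simulate_mle(x_star, sigma, T, rng):
    """Return mu(t) for t = 1..T: the running mean of observations x* + sigma * noise."""
    obs = x_star + sigma * rng.standard_normal((T, x_star.size))
    return np.cumsum(obs, axis=0) / np.arange(1, T + 1)[:, None]


def task_probabilities(mu_t, t, W, b, sigma):
    """Toy Bayes-optimal P(w_i . x* > b_i | evidence) for every task i (flat prior)."""
    scale = np.sqrt(t) / sigma  # posterior std of w_i . x* is sigma / sqrt(t)
    return np.array([STD_NORMAL.cdf(scale * (w @ mu_t - bi)) for w, bi in zip(W, b)])


def decode_mle(p, t, W, b, sigma):
    """Invert the probit outputs to linear measurements W mu = z and solve by least squares."""
    p = np.clip(p, 1e-12, 1 - 1e-12)  # guard against saturated outputs
    z = np.array([STD_NORMAL.inv_cdf(pi) for pi in p]) * sigma / np.sqrt(t) + b
    mu_hat, *_ = np.linalg.lstsq(W, z, rcond=None)
    return mu_hat


rng = np.random.default_rng(0)
D, N_TASK, SIGMA, T = 2, 6, 1.0, 10
x_star = rng.uniform(-0.5, 0.5, size=D)
angles = rng.uniform(0.0, np.pi, size=N_TASK)
W = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # unit-norm boundary normals
b = rng.uniform(-0.2, 0.2, size=N_TASK)

mu = simulate_mle(x_star, SIGMA, T, rng)
p = task_probabilities(mu[-1], T, W, b, SIGMA)
mu_hat = decode_mle(p, T, W, b, SIGMA)
print(np.allclose(mu_hat, mu[-1], atol=1e-6))
```

Because each probit-transformed output is linear in $\mu(t)$, accurate multi-task probabilities pin down the latent coordinates in this toy setting. Keeping only one task (`W[:1]`) leaves the system under-determined, and `np.linalg.lstsq` then returns the minimum-norm solution, i.e. the projection of $\mu(t)$ onto that single normal. The sketch illustrates the intuition only; the paper's precise conditions and proof are its own.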
Abstract
Intelligent perception and interaction with the world hinge on internal representations that capture its underlying structure ("disentangled" or "abstract" representations). Disentangled representations serve as world models, isolating latent factors of variation in the world along approximately orthogonal directions, thus facilitating feature-based generalization. We provide experimental and theoretical results guaranteeing the emergence of disentangled representations in agents that optimally solve multi-task evidence accumulation classification tasks, canonical in the neuroscience literature. The key conceptual finding is that, by producing accurate multi-task classification estimates, a system implicitly represents a set of coordinates specifying a disentangled representation of the underlying latent state of the data it receives. The theory provides conditions for the emergence of these representations in terms of noise, number of tasks, and evidence accumulation time. We experimentally validate these predictions in RNNs trained to multi-task, which learn disentangled representations in the form of continuous attractors, leading to zero-shot out-of-distribution (OOD) generalization in predicting latent factors. We demonstrate the robustness of our framework across autoregressive architectures, decision boundary geometries and in tasks requiring classification confidence estimation. We find that transformers are particularly suited for disentangling representations, which might explain their unique world-understanding abilities. Overall, our framework establishes a formal link between competence at multiple tasks and the formation of disentangled, interpretable world models in both biological and artificial systems, and helps explain why ANNs often arrive at human-interpretable concepts, and how both kinds of system may acquire exceptional zero-shot generalization capabilities.
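The roles of noise and evidence accumulation time can be illustrated in the same hypothetical toy model (Gaussian noise, flat prior, linear boundaries; `task_probability` and `mle_error` are illustrative names, not from the paper). With noise, a task's optimal output is a graded function of the signed margin $\mathbf w \cdot \mu(t) - b$, so its magnitude carries information about the latent coordinates. In the noise-free limit the output collapses to a hard 0/1 decision that only reports which side of the boundary the state lies on. Separately, the RMS error of the running-mean estimate $\mu(t)$ shrinks as $\sigma\sqrt{D/t}$.

```python
import numpy as np
from statistics import NormalDist

STD_NORMAL = NormalDist()  # standard normal distribution (mean 0, std 1)


def task_probability(margin, t, sigma):
    """Toy Bayes-optimal P(class = 1) for one linear task.

    margin is the signed distance w . mu(t) - b from the decision boundary.
    """
    if sigma == 0.0:
        # Noise-free limit: only the side of the boundary survives (hard 0/1 output).
        return float(margin > 0)
    # Posterior over w . x* is Gaussian with std sigma / sqrt(t).
    return STD_NORMAL.cdf(np.sqrt(t) * margin / sigma)


def mle_error(sigma, t, D=2, trials=2000, seed=0):
    """Monte Carlo RMS of ||mu(t) - x*|| when mu(t) is the running-mean estimate."""
    rng = np.random.default_rng(seed)
    # mu(t) - x* equals the average of the t noise samples, so x* drops out.
    noise = sigma * rng.standard_normal((trials, t, D))
    err = noise.mean(axis=1)
    return float(np.sqrt((err ** 2).sum(axis=1).mean()))


# Graded outputs under noise distinguish margins 0.1 and 0.3; hard outputs do not.
print(task_probability(0.1, 10, 1.0), task_probability(0.3, 10, 1.0))
print(task_probability(0.1, 10, 0.0), task_probability(0.3, 10, 0.0))

# Quadrupling the accumulation time roughly halves the estimation error.
print(mle_error(1.0, 4), mle_error(1.0, 16))
```

In this simplified picture, noise is what makes classification confidence informative about distance from each boundary, while longer accumulation sharpens the underlying estimate; how these trade off in trained networks is what the paper's theory and experiments address.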
