Disentangling Representations through Multi-task Learning

Pantelis Vafidis, Aman Bhargava, Antonio Rangel

TL;DR

The paper addresses how agents learning multiple tasks from noisy evidence can develop disentangled, abstract world-model representations. It introduces a formal framework in which the latent state $\mathbf Z(t)$ encodes a maximum-likelihood estimate $\mu(t)$ of the ground truth $\mathbf x^*$, and proves that disentanglement emerges when the number of tasks $N_{\text{task}}$ is at least the latent dimensionality $D$ and the noise is non-zero. The authors validate the theory with autoregressive models (RNNs/LSTMs) and GPT-2 transformers, showing continuous 2D attractors, orthogonal latent factors, and strong zero-shot OOD generalization, while demonstrating robustness across architectures and boundary geometries. The work has broad implications for representation learning and neuroscience: it links multi-task competence to world-model formation and suggests conditions under which neural and artificial systems converge to interpretable, topology-preserving representations. These insights offer a principled account of why complex models, including LLMs, can acquire human-aligned concepts and generalize to unseen scenarios through disentangled latent structure.
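A minimal sketch of the estimator $\mu(t)$, under assumptions this summary does not spell out: isotropic IID Gaussian noise $\mathbf X(s) = \mathbf x^* + \sigma\boldsymbol\epsilon(s)$ and an invertible transform $f$, so that conditioning on $f(\mathbf X)$ is equivalent to conditioning on $\mathbf X$. Under these assumptions the MLE is simply the running mean of the observations, and its error shrinks roughly like $\sigma/\sqrt t$. The function `running_mle` and all parameter values are hypothetical.

```python
import numpy as np


def running_mle(observations: np.ndarray) -> np.ndarray:
    """Running MLE mu(t) of x* for t = 1..T (hypothetical helper).

    Assumes X(t) = x* + sigma * eps(t) with eps(t) IID standard normal, so the
    likelihood is maximized by the sample mean of the first t observations.
    `observations` has shape (T, D); the result has the same shape.
    """
    counts = np.arange(1, observations.shape[0] + 1)[:, None]
    return np.cumsum(observations, axis=0) / counts


rng = np.random.default_rng(0)
D, T, sigma = 2, 400, 0.5
x_star = rng.uniform(-0.5, 0.5, size=D)           # hypothetical ground truth
X = x_star + sigma * rng.standard_normal((T, D))  # noisy evidence stream
mu = running_mle(X)

# The estimation error shrinks roughly like sigma / sqrt(t).
for t in (1, 25, 400):
    print(t, np.linalg.norm(mu[t - 1] - x_star))
```

The printed errors illustrate how the precision of $\mu(t)$ depends on evidence accumulation time.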

Abstract

Intelligent perception and interaction with the world hinge on internal representations that capture its underlying structure ("disentangled" or "abstract" representations). Disentangled representations serve as world models, isolating latent factors of variation in the world along approximately orthogonal directions, thus facilitating feature-based generalization. We provide experimental and theoretical results guaranteeing the emergence of disentangled representations in agents that optimally solve multi-task evidence accumulation classification tasks, canonical in the neuroscience literature. The key conceptual finding is that, by producing accurate multi-task classification estimates, a system implicitly represents a set of coordinates specifying a disentangled representation of the underlying latent state of the data it receives. The theory provides conditions for the emergence of these representations in terms of noise, number of tasks, and evidence accumulation time. We experimentally validate these predictions in RNNs trained on multiple tasks, which learn disentangled representations in the form of continuous attractors, leading to zero-shot out-of-distribution (OOD) generalization in predicting latent factors. We demonstrate the robustness of our framework across autoregressive architectures, decision boundary geometries and in tasks requiring classification confidence estimation. We find that transformers are particularly suited for disentangling representations, which might explain their unique world understanding abilities. Overall, our framework establishes a formal link between competence at multiple tasks and the formation of disentangled, interpretable world models in both biological and artificial systems, and helps explain why ANNs often arrive at human-interpretable concepts, and how they both may acquire exceptional zero-shot generalization capabilities.
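One way to see why accurate multi-task outputs pin down coordinates is a short worked derivation under assumptions this summary does not state: a flat prior on $\mathbf x^*$, isotropic Gaussian noise $\mathbf X(s) = \mathbf x^* + \sigma\boldsymbol\epsilon(s)$, an invertible transform $f$, and linear boundaries $\mathbf c_i^\top \mathbf x = b_i$, where the rows $\mathbf c_i^\top$ form $\mathbf C$ and the offsets $b_i$ are introduced here for illustration. $\Phi$ denotes the standard normal CDF. The posterior is Gaussian around the running mean, so each Bayes-optimal task output depends on the data only through $\mu(t)$:

$$ p\left(\mathbf x^* \mid \mathbf X(1), \dots, \mathbf X(t)\right) = \mathcal N\left(\mu(t), \tfrac{\sigma^2}{t}\mathbf I\right), \qquad \mu(t) = \frac{1}{t}\sum_{s=1}^{t} \mathbf X(s) $$

$$ \hat Y_i(t) = P\left(\mathbf c_i^\top \mathbf x^* > b_i \mid \mathbf X(1), \dots, \mathbf X(t)\right) = \Phi\left(\frac{\sqrt{t}\,\left(\mathbf c_i^\top \mu(t) - b_i\right)}{\sigma \lVert \mathbf c_i \rVert}\right) $$

Since $\Phi$ is strictly increasing, $\Phi^{-1}(\hat Y_i(t))$ is an affine function of $\mathbf c_i^\top \mu(t)$. With $\mathbf C$ full rank and $N_\textit{task} \geq D$, these $N_\textit{task}$ equations determine $\mu(t)$ uniquely, while $\sigma > 0$ keeps every output strictly between 0 and 1 so that the inversion is possible. This is a sketch consistent with the stated conditions, not necessarily the paper's own derivation.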

Paper Structure (68 sections, 12 theorems, 43 equations, 17 figures, 1 table)

Key Result

Theorem 3.1

If $\mathbf C \in \mathbb R^{N_\textit{task}\times D}$ is a full-rank matrix, $N_\textit{task} \geq D$, and the noise $\sigma > 0$, then $\mathbf Z(t)$ must encode $\mu(t)$ … Specifically, $\mu(t)$ is the maximum likelihood estimate (MLE) of $\mathbf x^*$ given the observations $f(\mathbf X(1)), \dots, f(\mathbf X(t))$. A closed-form expression for extracting $\mu(t)$ from $\mathbf Z(t)$ when $N_\textit{task} \geq D$ is: … where $\Phi$ is the …
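The closed-form expression is elided in this summary. As an illustration only, not the paper's formula, the sketch below assumes each optimal task output has the probit form $\Phi\big(\sqrt t(\mathbf c_i^\top\mu(t) - b_i)/(\sigma\lVert\mathbf c_i\rVert)\big)$, with $\Phi$ the standard normal CDF and hypothetical offsets $b_i$. It shows how applying $\Phi^{-1}$ turns the $N_\textit{task}$ outputs into linear equations in $\mu(t)$, which least squares solves uniquely when $\mathbf C$ has full column rank ($N_\textit{task} \geq D$). The helpers `task_outputs` and `recover_mu` and the boundary layout are hypothetical.

```python
import numpy as np
from statistics import NormalDist

PHI = NormalDist()  # standard normal: PHI.cdf is Phi, PHI.inv_cdf is Phi^{-1}


def task_outputs(mu, C, b, sigma, t):
    """Assumed probit outputs Phi(sqrt(t) * (c_i . mu - b_i) / (sigma * ||c_i||))."""
    scale = sigma * np.linalg.norm(C, axis=1) / np.sqrt(t)
    return np.array([PHI.cdf(z) for z in (C @ mu - b) / scale])


def recover_mu(y, C, b, sigma, t):
    """Invert the outputs with Phi^{-1}, giving N_task linear equations in mu.

    Least squares returns the unique solution only if C (N_task x D) has
    full column rank, which requires N_task >= D.
    """
    scale = sigma * np.linalg.norm(C, axis=1) / np.sqrt(t)
    y = np.clip(y, 1e-12, 1 - 1e-12)        # keep Phi^{-1} finite
    z = np.array([PHI.inv_cdf(p) for p in y])
    rhs = z * scale + b                      # c_i . mu = rhs_i for each task i
    mu_hat, *_ = np.linalg.lstsq(C, rhs, rcond=None)
    return mu_hat


# Hypothetical setup: 6 unit-norm boundaries through the origin at evenly spaced angles.
N_task, D, sigma, t = 6, 2, 0.5, 10
angles = np.linspace(0, np.pi, N_task, endpoint=False)
C = np.stack([-np.sin(angles), np.cos(angles)], axis=1)
b = np.zeros(N_task)
mu_true = np.array([0.3, -0.2])

y = task_outputs(mu_true, C, b, sigma, t)
print(recover_mu(y, C, b, sigma, t))  # recovers mu_true up to rounding
```

Dropping to a single task (`C[:1]`) leaves the system underdetermined: least squares then returns the minimum-norm solution, which generally differs from `mu_true`, mirroring the $N_\textit{task} \geq D$ condition.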

Figures (17)

  • Figure 1: Disentangled representations and a framework to learn them. (a) A disentangled representation directly lends itself to OOD generalization: a downstream linear decoder that can differentiate ripe from unripe bananas can readily generalize to mangos, even though it has never been trained on mangos. (b) Overview of our multi-task classification framework. A ground truth $\mathbf{x}^*$ is sampled and Gaussian noise is added to arrive at observations $\{\mathbf X(1), \dots, \mathbf X(t)\}$. These observations are processed by a filter-based model (illustrated graphically in the theory setup figure) that maintains a latent state $\mathbf{Z}(t)$. The latent state $\mathbf{Z}(t)$ is then used to produce classification outputs $\hat{Y}_1(t)$, $\hat{Y}_2(t)$. Theorem 3.1 proves that $\mathbf{Z}(t)$ must encode the optimal estimator of $\mathbf{x}^*$ given the noisy observations, $\mu(t)$.
  • Figure 2: Data generation and architecture. (a) For each trial, we sample a ground-truth vector $\mathbf x^*$ and add IID noise to arrive at $\mathbf X(t)$. The task is to report whether $\mathbf x^*$ lies above ($1$) or below ($0$) each of the classification lines (colors match the corresponding boolean variables in $y$), given the noisy and non-linearly transformed samples $f(\mathbf X(1)), \dots, f(\mathbf X(t))$. (b) Models (RNN depicted) are trained to report the outcomes of all the binary classifications in (a) at the end of the trial (indicated by the fixation input turning 0).
  • Figure 3: Learning disentangled representations. (a) ID and OOD generalization performance for networks trained on different numbers of tasks $N_\textit{task}$. We report the 25th, 50th and 75th percentiles of $r^2$ for each network size (see the appendix). ID and OOD performance increase with $N_\textit{task}$, and the generalization gap decreases, indicating that the networks have learned abstract representations. (b) The results hold for other autoregressive architectures, including LSTMs and GPT-2 transformers. (c) Angles between latent factor decoders (see the appendix for how they were estimated). The angles approach 90 degrees as $N_{\text{task}} \gg D$ for RNNs, but already for $N_{\text{task}} \geq D$ for transformers. Remaining deviations from 90 degrees are attributed to variability in the linear decoder fits. (d) Top 3 PCs of RNN activity ($N_\textit{task}=24$, $D=2$), capturing 85% of the variance (see inset). Each line is a trial, and color saturation indicates time. All trials start from the center and move outwards, towards the location of $\mathbf x^*$ in state space. We color the last time point in each trial (squares) according to the quadrant the trial was drawn from. Red x's correspond to attractors (see the appendix). Here we remove input noise so that trajectories can be visualized more easily. The network learns a two-dimensional continuous attractor that provides a disentangled representation of the 2D state space. (e) Spectral plot resulting from linearizing the RNN dynamics around every fixed point (see the appendix). The first two eigenvalues of the difference system are near $0$, while the rest decay much faster, indicating marginal stability along two dimensions at every fixed point, a signature of a 2D continuous attractor.
  • Figure 4: RNN and GPT representations and their relation to latent variables. (a) Hidden layer activations of the RNN in Figure 2b (left) and the GPT-2 transformer (right), while systematically varying the latent factors $x_1$ and $x_2$ from -0.5 to 0.5. Activations are plotted in $8 \times 8$ grids, one for each value of $x_1$ and $x_2$. Each grid contains firing rates for a total of 64 neurons for the RNN, and activations for 8 units for each of the 8 heads from the final embedding of the sequence for GPT-2. (b) Correlation coefficients of activations for both models with $x_1$ and $x_2$, respectively.
  • Figure 5: Experiments confirm theoretical predictions. (a) OOD $r^2$ for a free RT RNN required to report its estimate of $\mathbf x^*$ at different times (see the appendix for details; $N_\textit{task}=24$, $D=2$). The maximum network $r^2$ matches the predictions of the optimal multi-task classifier theory (see the theoretical $r^2$ equation in the appendix). (b) OOD $r^2$ as a function of input dimensionality $D$ and number of tasks $N_\textit{task}$. High values of $r^2$ are obtained when $N_\textit{task} \geq D$, especially for GPT models, confirming our theoretical results. (c) Increasing amounts of noise in pretraining result in better OOD generalization ($D=2$).
  • ...and 12 more figures
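The trial structure described for Figure 2 can be sketched as follows. This is a minimal sketch under stated assumptions, not the authors' data pipeline: the classification lines are parameterized by hypothetical normals $\mathbf c_i$ (with positive $y$-component, so $\mathbf c_i^\top \mathbf x^* > b_i$ means "above") and offsets $b_i$, $f = \tanh$ is a hypothetical choice of nonlinear transform, and the fixation channel is assumed to drop to 0 on the last time step. The helper `make_trial` and all parameter values are hypothetical.

```python
import numpy as np


def make_trial(rng, C, b, T=20, sigma=0.2, f=np.tanh):
    """Generate one multi-task evidence-accumulation trial (hypothetical helper).

    C: (N_task, D) boundary normals, b: (N_task,) offsets. Returns the input
    sequence [f(X(t)), fixation], the label vector y, and the ground truth.
    """
    D = C.shape[1]
    x_star = rng.uniform(-0.5, 0.5, size=D)           # ground truth x*
    X = x_star + sigma * rng.standard_normal((T, D))   # IID noisy samples X(1..T)
    fixation = np.ones((T, 1))
    fixation[-1] = 0.0                                 # fixation off: report now
    inputs = np.concatenate([f(X), fixation], axis=1)  # shape (T, D + 1)
    y = (C @ x_star > b).astype(int)                   # 1 = above line i, 0 = below
    return inputs, y, x_star


rng = np.random.default_rng(0)
N_task, D = 24, 2
theta = rng.uniform(-np.pi / 2, np.pi / 2, N_task)     # line angles; cos(theta) > 0
C = np.stack([-np.sin(theta), np.cos(theta)], axis=1)  # upward-pointing unit normals
b = rng.uniform(-0.25, 0.25, N_task)
inputs, y, x_star = make_trial(rng, C, b)
print(inputs.shape, y.shape)  # (20, 3) (24,)
```

Note that the labels depend on the noiseless $\mathbf x^*$, while the model only sees the transformed noisy samples, which is what forces evidence accumulation.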

Theorems & Definitions (23)

  • Theorem 3.1: Disentangled Representation Theorem
  • proof
  • Lemma B.1
  • proof
  • Lemma B.2
  • proof
  • Lemma B.3
  • proof
  • Theorem B.4: Trilateration Theorem
  • proof
  • ...and 13 more