
A mathematical theory for understanding when abstract representations emerge in neural networks

Bin Wang, W. Jeffrey Johnston, Stefano Fusi

TL;DR

The paper addresses why abstract, disentangled representations emerge when neural networks are trained on tasks tied to latent variables. It develops a mean‑field analytical framework that converts weight optimization in a two‑layer nonlinear network into a convex optimization over preactivation distributions, with an effective energy $E(\mathbf{h};\rho)$ and a representation kernel $K[\rho]$ that depend only on the input and output geometry ($K_X$, $K_Y$). For whitened or target‑aligned inputs, the analysis yields explicit, low‑rank, abstract hidden-layer representations for ReLU and for broad families of nonlinearities, with a universal kernel form $K[\rho_*] = b_*(d_Y\mathbf{1}\mathbf{1}^T + K_Y)$ and, in many cases, modular neuron tuning. Extensions to anisotropic geometries and deep architectures show that abstract representations persist and generalize to multi‑layer and recurrent networks, offering a tractable toolkit for understanding task‑driven representation learning. These results connect brain observations of low‑dimensional, abstract coding to a rigorous mathematical mechanism and provide a general framework for studying representation emergence in permutation‑symmetric, task‑driven networks.
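To make the universal kernel form concrete, the sketch below builds $K = b_*(d_Y\mathbf{1}\mathbf{1}^T + K_Y)$ for a toy task. It assumes that the outputs are the binary latent labels in $\{-1,+1\}$, that $K_Y = YY^T$, and that $b_* = 1$ (a placeholder scale); the helper names and the option of repeating each class $n$ times (the multi-element case of Figure 5) are illustrative, not the paper's code. Because the columns of $Y$ are balanced and mutually orthogonal, the kernel has rank $d_Y + 1$ and factors as $ZZ^T$ with $Z = [\sqrt{d_Y}\,\mathbf{1},\, Y]$, which places each latent variable on its own axis.

```python
import itertools

import numpy as np


def latent_labels(d_y, n=1):
    """Binary latent labels in {-1, +1} for all 2**d_y classes, each class repeated n times."""
    classes = np.array(list(itertools.product([1.0, -1.0], repeat=d_y)))
    return np.repeat(classes, n, axis=0)  # shape (P, d_y) with P = n * 2**d_y


def universal_kernel(Y, b=1.0):
    """K = b * (d_Y 1 1^T + K_Y), assuming K_Y = Y Y^T; b stands in for the scale b_*."""
    P, d_y = Y.shape
    return b * (d_y * np.ones((P, P)) + Y @ Y.T)


def abstract_embedding(Y):
    """One factorization Z with Z Z^T = d_Y 1 1^T + Y Y^T: a constant offset plus one axis per latent variable."""
    P, d_y = Y.shape
    return np.hstack([np.sqrt(d_y) * np.ones((P, 1)), Y])


Y = latent_labels(d_y=3, n=2)      # 8 classes, 2 stimuli per class, so P = 16
K = universal_kernel(Y)
Z = abstract_embedding(Y)
print(np.linalg.matrix_rank(K))    # 4, i.e. d_Y + 1
print(np.allclose(K, Z @ Z.T))     # True
```

A kernel (Gram) matrix fixes the point configuration only up to an orthogonal transformation, and such transformations preserve the parallelism of coding directions, so any hidden representation with this kernel is abstract. With $n = 2$, stimuli of the same class have identical rows, giving the block structure described for Figure 5.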

Abstract

Recent experiments reveal that task-relevant variables are often encoded in approximately orthogonal subspaces of the neural activity space. These disentangled low-dimensional representations are observed in multiple brain areas and across different species, and are typically the result of a process of abstraction that supports simple forms of out-of-distribution generalization. The mechanisms by which such geometries emerge remain poorly understood, and the mechanisms that have been investigated are typically unsupervised (e.g., based on variational auto-encoders). Here, we show mathematically that abstract representations of latent variables are guaranteed to appear in the last hidden layer of feedforward nonlinear networks when they are trained on tasks that depend directly on these latent variables. These abstract representations reflect the structure of the desired outputs or the semantics of the input stimuli. To investigate the neural representations that emerge in these networks, we develop an analytical framework that maps the optimization over the network weights into a mean-field problem over the distribution of neural preactivations. Applying this framework to a finite-width ReLU network, we find that its hidden layer exhibits an abstract representation at all global minima of the task objective. We further extend these analyses to two broad families of activation functions and deep feedforward architectures, demonstrating that abstract representations naturally arise in all these scenarios. Together, these results provide an explanation for the widely observed abstract representations in both the brain and artificial neural networks, as well as a mathematically tractable toolkit for understanding the emergence of different kinds of representations in task-optimized, feature-learning network models.
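The key step of the framework, replacing weights with the distribution of neural preactivations, can be illustrated numerically. The sketch below is a minimal illustration under stated assumptions: a ReLU hidden layer, whitened inputs $X = I$, random (untrained) weights, and an unnormalized kernel $K = \sum_i \phi(\mathbf{h}_i)\phi(\mathbf{h}_i)^T$ summed over hidden neurons $i$. The sizes and variable names are hypothetical. It checks that the hidden kernel depends on the network only through the collection of per-neuron preactivation vectors (the empirical measure), so relabeling the neurons leaves it unchanged; this is the permutation symmetry that a single-neuron mean-field description relies on.

```python
import numpy as np

rng = np.random.default_rng(0)
P, N, M = 4, 4, 50            # stimuli, input dimension, hidden width (illustrative)
X = np.eye(P)                 # whitened inputs: K_X = X X^T = I
W = rng.normal(size=(N, M))   # input-to-hidden weights (random, not trained)
b = rng.normal(size=M)        # hidden biases


def relu(h):
    return np.maximum(h, 0.0)


# Preactivation matrix H: row = stimulus, column = hidden neuron.
H = X @ W + b

# Hidden kernel computed directly from the activity matrix.
K_direct = relu(H) @ relu(H).T

# The same kernel as a sum over the empirical measure rho = sum_i delta(h - h_i):
# each neuron's preactivation column h_i contributes phi(h_i) phi(h_i)^T.
K_measure = sum(np.outer(relu(h), relu(h)) for h in H.T)

# Permuting neurons leaves the measure, and therefore the kernel, unchanged.
perm = rng.permutation(M)
K_permuted = relu(H[:, perm]) @ relu(H[:, perm]).T

print(np.allclose(K_direct, K_measure), np.allclose(K_direct, K_permuted))  # True True
```

Since the kernel is linear in the measure, optimizing over weights can be recast as optimizing over a distribution of $P$-dimensional preactivation vectors, as the abstract describes.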

Paper Structure

This paper contains 20 sections, 58 equations, 8 figures.

Figures (8)

  • Figure 1: Abstract representations. (A) Each stimulus in the task (e.g. an image of a handwritten digit) is associated with several binary variables (e.g. the parity and magnitude of the digit). An abstract representation is one in which each binary variable is represented along a single axis of the population activity space, as shown in the top plot inside the frame. This geometric property can be quantified with the parallelism score ($PS$), which measures how parallel the coding directions for one variable remain when the other variables vary. For example, it would measure how parallel the parity coding direction is for small and for large digits (the two values of the other variable, magnitude). Abstract representations are low-dimensional and have $PS=1$. In an alternative neural representation (bottom in the frame), the points representing the different digits lie at the vertices of a tetrahedron, the highest-dimensional arrangement available to four points. This representation is not abstract and has $PS\sim0$. (B) Magnitude and parity are equally decodable in both geometries.
  • Figure 2: Model set-up and summary of the main results. (Middle) Two-layer nonlinear networks are trained on tasks whose outputs are binary latent variables. The weight and bias parameters in each layer are optimized for the task, and the hidden layer has width $M$. For a range of input geometries (characterized by different input kernel matrices $K_X$) and the specified output geometry (in which each output label is exactly a latent label), the optimal hidden representation is always abstract. (Top right) The output labels for each stimulus are exactly its associated binary latent labels. (Top left) The input geometry varies smoothly from a fully orthogonalized input, in which different stimuli are represented by orthogonal vectors, to an input geometry that is fully aligned with the outputs. To illustrate the orthogonalized input in the $4$-dimensional space, we draw its 3D projection, which is a tetrahedron.
  • Figure 3: The analytical framework. (A) The neural preactivation patterns for all $P$ stimuli can be captured in the preactivation matrix $H$ [Eq. \ref{eq:preact_matrix}]. Each row represents the $M$ hidden neurons' preactivations for a specific stimulus. Each column represents the preactivations of a specific hidden neuron for all $P$ stimuli. (B) The column vectors of $H$ can be plotted in a $P$-dimensional space that encodes each hidden neuron's tuning to all $P$ stimuli. (C) The statistics of the hidden neurons' preactivations can be captured by the empirical (unnormalized) measure. The hidden representation kernel matrix is a linear function of this empirical measure. (D-F) Mathematically, finding the optimal network reduces to determining the ground state of an effective $M$-neuron system whose interactions are governed by the input and output kernel matrices [Eq. \ref{eq:loss_h}], which is further equivalent to a mean-field problem in which a single representative neuron interacts with the statistics of the neural activity in the network [Eq. \ref{eq:MF_SingleNeuron}].
  • Figure 4: Task-optimized ReLU networks exhibit abstract representations for whitened and target-aligned inputs. (A) The training loss and the parallelism score of the hidden representation are plotted against the number of training epochs for the whitened input; the network is trained by gradient descent. After training, the network performs the task perfectly, with zero training loss, and has an abstract hidden representation ($PS \rightarrow 1$). (B) The optimal hidden kernel predicted by theory ($K_{theory}$, given by Eq. \ref{eq:opt_kernel}) agrees with the one found in numerical simulations ($K_{sim}$) for different output dimensions $d_Y$. (C) The parallelism score of the hidden representation after training as a function of the input-output alignment (see SI § 5 for the definition of input-output alignment). Each point represents a specific input geometry with a randomly sampled input kernel, and the red dot marks the whitened input kernel. For inputs more aligned with the output than the whitened one, the $PS$ is close to $1$. (D-E) Modularity of the single-neuron tuning in the hidden layer is captured by the preactivation vector $\mathbf{h}$ and the weight vector $\mathbf{v}$ to the output layer for each hidden neuron; here $d_Y = 2$. (E) Each hidden neuron has nonzero output weights to only a single output unit, suggesting that different hidden neurons are used to "read out" different output labels. (F) The first three principal components of the neural preactivation space (the same space as Fig. 3B). The preactivation vectors of the hidden neurons concentrate along $P=4$ distinct directions, as predicted by Eq. \ref{eq:h_opt}.
  • Figure 5: Task-optimized ReLU networks for multi-element classes ($n\geq 2$). (A) An example of an input kernel with nonzero within-class correlation but zero between-class correlation (left). For illustration, this example assumes the same within-class correlation matrix $C$ for all classes, but our theory also applies to class-specific correlations [Eq. \ref{eq:ortho_input_multi}]. The theory predicts that the optimal hidden kernel $K_{theory}$ is proportional to the output kernel up to a positive shift (right). In particular, the hidden kernel has a block structure because all data points within each binary class share the same hidden representation [Eq. \ref{eq:Optimal_h_multielement}]. (B) Training loss and $PS$ are plotted against the number of training epochs. (C) The predicted hidden kernel ($K_{theory}$) agrees with the simulated one ($K_{sim}$) as the number of data points per class $n$ varies.
  • ...and 3 more figures
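The parallelism score described in Figure 1 can be computed directly in the two-variable case. The sketch below is a simplified version for two binary variables only: it takes the cosine between the coding directions of one variable at the two values of the other variable. The function name and the toy geometries are illustrative, and the paper's exact definition, which may cover more variables and multiple condition pairings, is not reproduced here.

```python
import numpy as np


def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


def parallelism_score(R, labels, var):
    """Simplified PS for two binary latent variables (hypothetical helper).

    R: (P, D) representation, one row per condition.
    labels: (P, 2) latent labels in {-1, +1}.
    var: index of the variable whose coding direction is measured.
    """
    other = 1 - var
    directions = []
    for o in (1.0, -1.0):  # each value of the other variable
        pos = R[(labels[:, var] == 1) & (labels[:, other] == o)][0]
        neg = R[(labels[:, var] == -1) & (labels[:, other] == o)][0]
        directions.append(pos - neg)  # coding direction for `var` at this value
    return cosine(directions[0], directions[1])


labels = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1]], dtype=float)
abstract_R = labels.copy()   # each variable along its own axis (top of Fig. 1A)
tetra_R = np.eye(4)          # orthogonal vectors: vertices of a regular tetrahedron

print(parallelism_score(abstract_R, labels, 0))  # 1.0
print(parallelism_score(tetra_R, labels, 0))     # 0.0
```

In the orthogonal (tetrahedral) geometry the two parity directions share no components, so their cosine is zero, while both geometries still allow each variable to be decoded linearly, as panel (B) notes.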