RNNs perform task computations by dynamically warping neural representations

Arthur Pellegrino, Angus Chadwick

TL;DR

The paper introduces a Riemannian-geometric framework that links the topology of low-dimensional input manifolds to the geometry of the dynamical neural state manifolds via a pullback metric derived from adjoint dynamics. It demonstrates that recurrent networks solve time-dependent tasks by dynamically warping their internal representations, selectively compressing irrelevant inputs while aligning relevant variables with readouts. Across static and sequential tasks, the authors show that computations manifest as context-dependent warping of manifolds (e.g., circles to warped grids, 3D manifolds to line attractors, and hyper-tori for memory). This approach provides a principled, mathematically tractable way to interpret how geometry encodes computation in RNNs and offers a foundation for analyzing data-driven dynamical systems.

Abstract

Analysing how neural networks represent data features in their activations can help interpret how they perform tasks. Hence, a long line of work has focused on mathematically characterising the geometry of such "neural representations." In parallel, machine learning has seen a surge of interest in understanding how dynamical systems perform computations on time-varying input data. Yet, the link between computation-through-dynamics and representational geometry remains poorly understood. Here, we hypothesise that recurrent neural networks (RNNs) perform computations by dynamically warping their representations of task variables. To test this hypothesis, we develop a Riemannian geometric framework that enables the derivation of the manifold topology and geometry of a dynamical system from the manifold of its inputs. By characterising the time-varying geometry of RNNs, we show that dynamic warping is a fundamental feature of their computations.

Paper Structure

This paper contains 16 sections, 12 theorems, 41 equations, 9 figures, 1 table.

Key Result

Theorem 3.1

Consider a dynamical system as above, and let $\mathbf{u}\in\mathcal{M}$ where $\mathcal{M}$ is an $m$-dimensional manifold. Then $\mathbf{x}(t)\in \mathcal{N}$ where $\mathcal{N}=P(\mathcal{M}\times \mathbb{R})$ and $P$ is a projection map.
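The statement above says that each state is indexed by an input point and a time, and the Figure 3 caption describes the resulting tangent space as spanned by one vector per input parameter plus the vector field $\mathbf{f}$, with the metric given by their inner products. The following is a minimal sketch of that construction, not the authors' implementation. Several pieces are assumptions made for illustration: the vector field `-x + tanh(W x + kappa * b)`, the random weights `W` and `b`, the single scalar input parameter `kappa` (so $m=1$), and the use of central finite differences in place of the adjoint dynamics that the paper uses for the input-direction tangent vector.

```python
import numpy as np

def simulate(kappa, W, b, n_steps, dt=0.05):
    """Euler-integrate dx/dt = -x + tanh(W x + kappa * b) from x(0) = 0.

    Hypothetical vector field; the input is constant in time and
    parameterised by a single scalar kappa.
    """
    x = np.zeros(W.shape[0])
    for _ in range(n_steps):
        x = x + dt * (-x + np.tanh(W @ x + kappa * b))
    return x

def state_metric(kappa, n_steps, W, b, dt=0.05, eps=1e-5):
    """2x2 metric at (kappa, t): inner products of the two tangent vectors."""
    x = simulate(kappa, W, b, n_steps, dt)
    # Tangent along the input parameter, by central finite differences
    # (an assumed stand-in for the adjoint-based computation).
    d_kappa = (simulate(kappa + eps, W, b, n_steps, dt)
               - simulate(kappa - eps, W, b, n_steps, dt)) / (2 * eps)
    # Tangent along time: the vector field evaluated at the current state.
    d_time = -x + np.tanh(W @ x + kappa * b)
    T = np.stack([d_kappa, d_time], axis=1)  # shape (n_units, 2)
    return T.T @ T

rng = np.random.default_rng(0)
n_units = 20
# Small recurrent weights so that the dynamics settle to a fixed point.
W = 0.3 * rng.standard_normal((n_units, n_units)) / np.sqrt(n_units)
b = rng.standard_normal(n_units)

G_early = state_metric(kappa=0.5, n_steps=5, W=W, b=b)
G_late = state_metric(kappa=0.5, n_steps=2000, W=W, b=b)
```

In this toy system the time-time entry `G_late[1, 1]` shrinks towards zero once the state has settled, while the input-input entry stays positive. This is the same qualitative behaviour that the Figure 4 caption attributes to convergence onto an attractor state.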

Figures (9)

  • Figure 1: Dynamical systems receiving inputs on a low-dimensional manifold of functions are constrained to a low-dimensional manifold of states. The Riemannian geometry of this manifold can provide insights into dynamical representations and computations.
  • Figure 2: The pullback metric captures features of task computation. a. Network with one hidden layer and tanh nonlinearity trained to map $\mathbf{x}=[\cos(\theta), \sin(\theta)]$ to $\mathbf{y}=\mathds{1}_{\theta\in[0, \pi)}-\mathds{1}_{\theta\in (\pi, 2\pi)}$. b. Activation of the three hidden units of the network in response to the inputs. c. Metric learned by the network, represented as gridlines on the input manifold, illustrating that the manifold has been warped around the class boundary.
  • Figure 3: Dynamical systems' states lie on manifolds whose geometry captures task computations. a. Schematic of how a manifold of time-varying inputs generates a dynamical system manifold. Each point $\mathbf{u}$ on the input manifold $\mathcal{M}$ is a time-varying function. The integral function $\varphi$ maps this input function to the time-varying solution to the dynamical system. The tangent space of the dynamical system manifold is thus spanned by two vectors corresponding to small changes in the input parameters $\boldsymbol{\kappa}$ (given by the adjoint $\mathbf{a}$) and small changes in the time point $t$ (given by $\mathbf{f}$). b. E-I network of decision making, whose constant-in-time inputs lie on a line manifold. c. The metric captures the transition from the initial state to a line of attractor fixed points. d. The metric is the inner product between the tangent vectors. e. Single-neuron activity across different trajectories.
  • Figure 4: Neural manifold warping of irrelevant inputs during contextual evidence integration. a. Task schematic: the task consists of two noisy inputs with different expected magnitudes. Depending on the contextual cue, the network has to output the sign of the relevant input. Right: example outputs under the same noise realisation across different magnitudes of the relevant input. b. Because the input lies on a $2$D manifold, neural activity lies on a $3$D manifold (here shown in the space of three neurons). Further, there are two manifolds, one for each context. These manifolds are bounded by planes of attractor fixed points (see panel e). c. Changes in the relevant (but not irrelevant) input lead to changes in the decoder output at late (but not early) time points. Error bars are across the irrelevant input. d. This ensures that the output is not affected by variability in the irrelevant input. e. $2$D time-slices of the manifold. The metric goes from being everywhere diagonal with equal diagonal entries at early time points to having a near-zero time component (because the network has converged to an attractor state) and a dominant input space component that depends on the context. f. Geodesic gridlines under the pullback of the metric at two different time points highlight that space becomes stretched near the decision boundary along the relevant input. g. Eigenvalues of the metric (normalised to $1$, error bars across the irrelevant input). The manifold becomes pseudo-Riemannian at large times, such that only the largest eigenvalue (corresponding to the relevant input) is large. h. Alignment of the largest eigenvector of the weight updates with each basis vector of the tangent space. Over time the largest eigenvector becomes more aligned with the relevant input. i. The change in activation of neurons following a change along each basis vector of the tangent space follows a Gaussian distribution. The covariance of this distribution shows correlation or anti-correlation between the time and relevant-input directions, but not between the time and irrelevant-input directions.
  • Figure 5: Task where a network has to remember two sequential inputs and then output them after a variable delay. Possible task solutions: i) dynamically warping the torus to compress the irrelevant input's representation or ii) realigning the torus' relevant input encoding direction to the decoder subspace while preserving a fixed intrinsic geometry.
  • ...and 4 more figures
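
The pullback metric described in the Figure 2 caption can be sketched for a one-hidden-layer tanh network that receives $\mathbf{x}=[\cos(\theta), \sin(\theta)]$. This is a hedged illustration, not the trained network from the paper: the weights `W` and biases `c` are random and untrained, and the function names are hypothetical. With a single input coordinate $\theta$, the metric reduces to the scalar $g(\theta)=\lVert \partial\mathbf{h}/\partial\theta\rVert^2$, where $\mathbf{h}$ denotes the hidden activations.

```python
import numpy as np

def hidden(theta, W, c):
    """Hidden activations h(theta) = tanh(W [cos theta, sin theta] + c)."""
    x = np.array([np.cos(theta), np.sin(theta)])
    return np.tanh(W @ x + c)

def pullback_metric(theta, W, c):
    """Scalar metric g(theta) = |dh/dtheta|^2 induced on the circle."""
    h = hidden(theta, W, c)
    dx = np.array([-np.sin(theta), np.cos(theta)])  # dx/dtheta
    dh = (1.0 - h**2) * (W @ dx)                    # chain rule through tanh
    return float(dh @ dh)

rng = np.random.default_rng(1)
W = 2.0 * rng.standard_normal((3, 2))  # hypothetical, untrained weights
c = rng.standard_normal(3)             # hypothetical biases

thetas = np.linspace(0.0, 2.0 * np.pi, 400, endpoint=False)
g = np.array([pullback_metric(t, W, c) for t in thetas])
# Large g marks angles where the hidden representation is stretched;
# small g marks angles where it is compressed.
```

Plotting `g` against `thetas` shows where this random network stretches or compresses the circle. For a network trained on the classification task in Figure 2, the caption indicates that the warping concentrates around the class boundary.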

Theorems & Definitions (22)

  • Theorem 3.1
  • Corollary 3.2
  • Corollary 3.3
  • Theorem 3.4
  • Corollary 3.5
  • Proposition S7.1
  • Proposition S7.2
  • Definition S7.1
  • Definition S7.2
  • Definition S7.3
  • ...and 12 more