Measuring and Controlling Solution Degeneracy across Task-Trained Recurrent Neural Networks

Ann Huang, Satpreet H. Singh, Flavio Martinelli, Kanaka Rajan

TL;DR

This work tackles solution degeneracy, the tendency of independently trained RNNs that solve the same task to arrive at different internal solutions, by introducing a unified, multi-level framework that quantifies degeneracy across behavior, neural dynamics, and weights. Through a large-scale study across four neuroscience-relevant tasks and multiple control factors, the authors demonstrate both contravariant and covariant relationships: higher task complexity and stronger feature learning tend to make neural dynamics more consistent across networks while increasing weight diversity, whereas larger networks and structural regularization promote convergence at all levels. They validate the Contravariance Principle and offer practical guidance for tuning degeneracy, either to reveal shared neural mechanisms or to model the individual variability observed in biology. The framework and findings offer a principled path toward more interpretable and biologically grounded RNN models, with implications for ensemble modeling and hypothesis testing in neuroscience.

Abstract

Task-trained recurrent neural networks (RNNs) are widely used in neuroscience and machine learning to model dynamical computations. To gain mechanistic insight into how neural systems solve tasks, prior work often reverse-engineers individual trained networks. However, different RNNs trained on the same task and achieving similar performance can exhibit strikingly different internal solutions, a phenomenon known as solution degeneracy. Here, we develop a unified framework to systematically quantify and control solution degeneracy across three levels: behavior, neural dynamics, and weight space. We apply this framework to 3,400 RNNs trained on four neuroscience-relevant tasks (flip-flop memory, sine wave generation, delayed discrimination, and path integration) while systematically varying task complexity, learning regime, network size, and regularization. We find that higher task complexity and stronger feature learning reduce degeneracy in neural dynamics but increase it in weight space, with mixed effects on behavior. In contrast, larger networks and structural regularization reduce degeneracy at all three levels. These findings empirically validate the Contravariance Principle and provide practical guidance for researchers seeking to tune the variability of RNN solutions, either to uncover shared neural mechanisms or to model the individual variability observed in biological systems. This work provides a principled framework for quantifying and controlling solution degeneracy in task-trained RNNs, offering new tools for building more interpretable and biologically grounded models of neural computation.
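To make "quantifying degeneracy" concrete, here is a minimal sketch that assumes (as our own simplification, not necessarily the authors' metric) that degeneracy at a given level is the mean pairwise dissimilarity across independently trained networks. It also illustrates why weight-space comparisons must handle hidden-unit relabeling: two RNNs whose recurrent matrices differ only by a permutation of hidden units implement the same computation, yet a naive Frobenius distance reports them as different. All function names are hypothetical, and the brute-force permutation search is only tractable for tiny networks.

```python
import itertools
import math
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")
Matrix = list[list[float]]


def mean_pairwise_distance(items: Sequence[T], dist: Callable[[T, T], float]) -> float:
    """Average distance over all unordered pairs of networks (a hypothetical degeneracy score)."""
    pairs = list(itertools.combinations(items, 2))
    if not pairs:
        raise ValueError("need at least two networks")
    return sum(dist(a, b) for a, b in pairs) / len(pairs)


def frobenius_distance(a: Matrix, b: Matrix) -> float:
    """Naive element-wise distance between two weight matrices."""
    return math.sqrt(sum((x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)))


def permute(w: Matrix, perm: Sequence[int]) -> Matrix:
    """Relabel hidden units: W'[i][j] = W[perm[i]][perm[j]].

    For brevity only the recurrent matrix is permuted; a full treatment would
    permute input and readout weights consistently as well.
    """
    n = len(w)
    return [[w[perm[i]][perm[j]] for j in range(n)] for i in range(n)]


def permutation_invariant_distance(a: Matrix, b: Matrix) -> float:
    """Smallest Frobenius distance over all relabelings of b's hidden units (brute force)."""
    n = len(a)
    return min(frobenius_distance(a, permute(b, p)) for p in itertools.permutations(range(n)))


if __name__ == "__main__":
    w1 = [[0.0, 1.0, 0.0], [0.0, 0.0, 2.0], [3.0, 0.0, 0.0]]
    w2 = permute(w1, [2, 0, 1])  # the same network with its units relabeled
    print(frobenius_distance(w1, w2))              # positive: looks like a different network
    print(permutation_invariant_distance(w1, w2))  # 0.0: identical up to relabeling
```

Under this sketch, a population of networks would be scored as `mean_pairwise_distance(networks, permutation_invariant_distance)`; practical analyses replace the brute-force search with a scalable matching or alignment method.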

Paper Structure (64 sections, 18 equations, 33 figures, 7 tables)

Figures (33)

  • Figure 1: Key factors shape degeneracy across behavior, dynamics, and weights. Schematic of our framework for analyzing solution degeneracy in task-trained RNNs. We evaluate how task complexity, learning regime, network size, and structural regularization influence degeneracy at three levels: behavior (network outputs), neural dynamics (state trajectories), and weight space (connectivity).
  • Figure 2: Our task suite spans memory, integration, pattern generation, and decision-making. Task schematics and representative network trajectories projected onto the top principal components are shown in (A)–(B). The four tasks are as follows. N-Bit Flip-Flop: the network must remember the last nonzero input on each of $N$ independent channels. Delayed Discrimination: the network compares the magnitudes of two pulses separated by a variable delay and outputs the sign of their difference. Sine Wave Generation: a static input specifies a target frequency, and the network generates the corresponding sine wave over time. Path Integration: the network integrates velocity inputs to track its position in a bounded 2D or 3D arena (the schematic shows the 2D case).
  • Figure 3: Higher task complexity reduces dynamical and behavioral degeneracy, but increases weight degeneracy. (A) Two-dimensional MDS embedding of network dynamics shows that independently trained networks converge to more similar trajectories as task complexity increases. (B) Dynamical, (C) weight, and (D) behavioral degeneracy [temporal generalization] across 50 networks as a function of task complexity. Shaded area indicates $\pm 1$ standard error.
  • Figure 4: Increasing memory demand or adding auxiliary loss changes task complexity, which in turn modulates degeneracy. In the Delayed Discrimination task, both manipulations reduce dynamical and behavioral degeneracy [temporal generalization] while increasing weight degeneracy. The auxiliary loss also induces additional line attractors in the network’s dynamics, as shown in (C).
  • Figure 5: More complex tasks drive stronger feature learning in RNNs. Increased input–output dimensionality leads to higher weight-change norms ($\|\Delta W\|_F$) and lower kernel alignment (KA). Error bars indicate $\pm 1$ standard error.
  • ...and 28 more figures
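The N-Bit Flip-Flop task in Figure 2 is fully specified by its rule: each output channel holds the most recent nonzero input on the matching input channel. A minimal trial generator follows; the pulse probability, the ±1 pulse values, and the convention that targets start at 0 before the first pulse are assumptions of this sketch, and `flip_flop_trial` is a hypothetical name.

```python
import random
from typing import Optional


def flip_flop_trial(
    n_bits: int,
    n_steps: int,
    p_pulse: float = 0.05,
    rng: Optional[random.Random] = None,
) -> tuple[list[list[int]], list[list[int]]]:
    """Generate one N-bit flip-flop trial (hypothetical helper).

    inputs[t][k] is +1, -1, or 0 (no pulse) on channel k at step t.
    targets[t][k] is the last nonzero input on channel k up to and including step t;
    it is 0 before the first pulse, which is an assumption of this sketch.
    """
    rng = rng if rng is not None else random.Random()
    inputs: list[list[int]] = []
    targets: list[list[int]] = []
    memory = [0] * n_bits  # last nonzero input seen on each channel
    for _ in range(n_steps):
        # Each channel independently receives a +/-1 pulse with probability p_pulse.
        x = [rng.choice((-1, 1)) if rng.random() < p_pulse else 0 for _ in range(n_bits)]
        # A pulse overwrites that channel's memory; silence leaves it unchanged.
        memory = [xi if xi != 0 else mi for xi, mi in zip(x, memory)]
        inputs.append(x)
        targets.append(list(memory))
    return inputs, targets
```

Raising `n_bits` increases the task's input–output dimensionality, which is the kind of complexity manipulation that Figure 5 relates to stronger feature learning.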
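The Delayed Discrimination task in Figure 2 can be sketched as a single-channel trial: pulse A, a variable silent delay, pulse B, then a response window in which the target is the sign of the amplitude difference. Pulse lengths, amplitude and delay ranges, the response window, and the sign convention sign(A − B) are all assumptions of this hypothetical generator.

```python
import random
from typing import Optional


def delayed_discrimination_trial(
    pulse_len: int = 5,
    delay_range: tuple[int, int] = (10, 30),
    response_len: int = 10,
    amp_range: tuple[float, float] = (0.2, 1.0),
    rng: Optional[random.Random] = None,
) -> tuple[list[float], list[float], int]:
    """One trial: pulse A, variable delay, pulse B, then a response window.

    Returns (inputs, targets, delay). Targets are 0 until the response window,
    where they equal sign(a - b) (an assumed sign convention). Equal amplitudes
    are resampled so the answer is always +1 or -1.
    """
    rng = rng if rng is not None else random.Random()
    a = rng.uniform(*amp_range)
    b = rng.uniform(*amp_range)
    while b == a:
        b = rng.uniform(*amp_range)
    delay = rng.randint(*delay_range)  # variable delay, inclusive bounds
    inputs = [a] * pulse_len + [0.0] * delay + [b] * pulse_len + [0.0] * response_len
    label = 1.0 if a > b else -1.0
    # The network must stay silent until both pulses have been presented.
    targets = [0.0] * (2 * pulse_len + delay) + [label] * response_len
    return inputs, targets, delay
```

Lengthening `delay_range` is one hypothetical way to raise memory demand, the manipulation that Figure 4 reports as changing task complexity.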
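For Sine Wave Generation in Figure 2, the caption specifies a static input that encodes the target frequency and a sine-wave output over time. The sketch below assumes the frequency is fed in directly as a constant scalar and that the target is sin(2πft) sampled every `dt` seconds; both the encoding and the function name are hypothetical.

```python
import math


def sine_wave_trial(
    freq_hz: float, n_steps: int, dt: float = 0.01
) -> tuple[list[float], list[float]]:
    """Static frequency cue in, sine wave out (hypothetical helper).

    The input is constant for the whole trial; encoding it as the raw frequency
    value is an assumption of this sketch. The target is sin(2*pi*f*t) at t = k*dt.
    """
    inputs = [freq_hz] * n_steps
    targets = [math.sin(2.0 * math.pi * freq_hz * k * dt) for k in range(n_steps)]
    return inputs, targets
```

Because the input never changes within a trial, the network must generate the oscillation from its own recurrent dynamics rather than track a time-varying drive.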
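For Path Integration in Figure 2, the network receives velocities and must report position inside a bounded 2D or 3D arena. How the paper handles walls is not stated here, so this sketch assumes a velocity component that would cross a wall is zeroed before it is emitted. With that rule, the target is always the start position plus the integral of the inputs. The box size, speed range, start at the centre, and the function name are hypothetical.

```python
import random
from typing import Optional


def path_integration_trial(
    n_dims: int = 2,
    n_steps: int = 100,
    dt: float = 0.1,
    half_width: float = 1.0,
    speed: float = 0.5,
    rng: Optional[random.Random] = None,
) -> tuple[list[list[float]], list[list[float]]]:
    """Velocity in, position out, inside the box [-half_width, half_width]^n_dims.

    Boundary handling (an assumption of this sketch): a velocity component that
    would carry the agent past a wall is set to zero before it is emitted.
    """
    if n_dims not in (2, 3):
        raise ValueError("the task uses 2D or 3D arenas")
    rng = rng if rng is not None else random.Random()
    pos = [0.0] * n_dims  # start at the arena centre (assumed)
    inputs: list[list[float]] = []
    targets: list[list[float]] = []
    for _ in range(n_steps):
        v = [rng.uniform(-speed, speed) for _ in range(n_dims)]
        for d in range(n_dims):
            if abs(pos[d] + v[d] * dt) > half_width:
                v[d] = 0.0  # blocked by the wall
        pos = [p + vi * dt for p, vi in zip(pos, v)]  # Euler integration of velocity
        inputs.append(v)
        targets.append(list(pos))
    return inputs, targets
```

Switching `n_dims` from 2 to 3 adds one input and one output channel, another way to vary input–output dimensionality.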
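Figure 5 summarizes feature learning with two quantities: the weight-change norm ‖ΔW‖_F and kernel alignment (KA). The sketch below computes ‖W_final − W_init‖_F directly. For KA it assumes, since this excerpt does not say which kernel the authors use, an uncentered alignment between Gram matrices of hidden states recorded before and after training. Under that assumption, a lower KA means the representation moved further from its initial geometry, which is consistent with stronger feature learning. The function names are hypothetical.

```python
import math

Matrix = list[list[float]]


def weight_change_norm(w_init: Matrix, w_final: Matrix) -> float:
    """||W_final - W_init||_F: how far training moved the weights."""
    return math.sqrt(
        sum((b - a) ** 2 for ra, rb in zip(w_init, w_final) for a, b in zip(ra, rb))
    )


def gram(h: Matrix) -> Matrix:
    """K = H H^T, where each row of H is the hidden state for one input condition."""
    return [[sum(x * y for x, y in zip(hi, hj)) for hj in h] for hi in h]


def kernel_alignment(k0: Matrix, k1: Matrix) -> float:
    """Uncentered alignment <K0, K1>_F / (||K0||_F ||K1||_F); lies in [0, 1] for Gram matrices."""
    inner = sum(a * b for r0, r1 in zip(k0, k1) for a, b in zip(r0, r1))
    n0 = math.sqrt(sum(a * a for r in k0 for a in r))
    n1 = math.sqrt(sum(b * b for r in k1 for b in r))
    return inner / (n0 * n1)
```

In this sketch, KA would be evaluated as `kernel_alignment(gram(H_init), gram(H_trained))`. Alignment is insensitive to an overall rescaling of the hidden states, so it isolates changes in representational geometry, while ‖ΔW‖_F captures raw movement in weight space.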