
Geometry of naturalistic object representations in recurrent neural network models of working memory

Xiaoxuan Lei, Takuya Ito, Pouya Bashivan

TL;DR

These findings indicate that goal-driven RNNs employ chronological memory subspaces to track information over short time spans, enabling testable predictions with neural data.

Abstract

Working memory is a central cognitive ability crucial for intelligent decision-making. Recent experimental and computational work studying working memory has primarily used categorical (i.e., one-hot) inputs, rather than ecologically relevant, multidimensional naturalistic ones. Moreover, studies have primarily investigated working memory during single or few cognitive tasks. As a result, an understanding of how naturalistic object information is maintained in working memory in neural networks is still lacking. To bridge this gap, we developed sensory-cognitive models, comprising a convolutional neural network (CNN) coupled with a recurrent neural network (RNN), and trained them on nine distinct N-back tasks using naturalistic stimuli. By examining the RNN's latent space, we found that: (1) Multi-task RNNs represent both task-relevant and irrelevant information simultaneously while performing tasks; (2) The latent subspaces used to maintain specific object properties in vanilla RNNs are largely shared across tasks, but highly task-specific in gated RNNs such as GRU and LSTM; (3) Surprisingly, RNNs embed objects in new representational spaces in which individual object features are less orthogonalized relative to the perceptual space; (4) The transformation of working memory encodings (i.e., embedding of visual inputs in the RNN latent space) into memory was shared across stimuli, yet the transformations governing the retention of a memory in the face of incoming distractor stimuli were distinct across time. Our findings indicate that goal-driven RNNs employ chronological memory subspaces to track information over short time spans, enabling testable predictions with neural data.
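The nine N-back tasks can be illustrated with a minimal sketch of how match/non-match targets might be derived for one stimulus sequence. It assumes the three task-relevant features are location (L), identity (I) and category (C), crossed with N ∈ {1, 2, 3}. The dictionary stimulus format, the function name `nback_targets`, and the use of `None` for frames with no N-back predecessor are hypothetical illustrative choices, not the authors' code.

```python
from itertools import product

# Task-relevant features, assumed to be location (L), identity (I), category (C).
FEATURES = ("L", "I", "C")
N_VALUES = (1, 2, 3)

# The nine task variants: every (feature, N) combination.
TASKS = list(product(FEATURES, N_VALUES))

def nback_targets(stimuli, feature, n):
    """Per-frame targets for an n-back task on one feature (hypothetical helper).

    `stimuli` is a list of dicts mapping feature name -> value (assumed format).
    Frames with no n-back predecessor get None (no comparison is possible);
    later frames get True ("match") if the feature value equals the one seen
    n frames earlier, otherwise False ("non-match").
    """
    targets = []
    for t, stim in enumerate(stimuli):
        if t < n:
            targets.append(None)
        else:
            targets.append(stim[feature] == stimuli[t - n][feature])
    return targets

# Example trial with made-up feature values.
trial = [
    {"L": 0, "I": "chair_07", "C": "chair"},
    {"L": 2, "I": "car_13", "C": "car"},
    {"L": 1, "I": "chair_02", "C": "chair"},
    {"L": 2, "I": "car_13", "C": "car"},
]
print(len(TASKS))                    # 9
print(nback_targets(trial, "C", 2))  # [None, None, True, True]
print(nback_targets(trial, "I", 2))  # [None, None, False, True]
```

The third frame shows why task relevance matters: under 2-back category it is a match, while under 2-back identity the same frame is a non-match.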

Paper Structure

This paper contains 21 sections, 5 equations, 11 figures.

Figures (11)

  • Figure 1: Tasks and models: a) Example of a 2-back category task. Each object's category is compared with the category of the object seen two frames earlier. b) The suite of n-back tasks considered in the study. c) The sensory-cognitive model architecture. d) A schematic showing the latent subspaces for category, identity, and location in the perceptual, encoding, and memory spaces. Left: Stimuli are encoded in the high-dimensional latent space of the vision model (CNN), with each object property encoded in a high-dimensional latent subspace of this model. Right: The RNN represents each object property in its encoding latent subspace and retains some or all of the properties within its memory subspaces at later time points.
  • Figure 2: Representation of task-relevant/-irrelevant object properties: (a) Decoding generalization accuracy for each object property, displayed across tasks and operating modes for vanilla RNN and GRU. Rows and columns of the $3 \times 3$ matrices correspond to the $N$-back tasks on which the decoders are fitted and tested, respectively. Matrix columns correspond to particular decoders denoted by $D_{k,F}$ ($k\in\{1,2,3\}$, $F\in\{L, I, C\}$), indicating the task and decoding feature the decoder was fitted on, while matrix rows correspond to the object property of the task the decoder was tested on. (b) Validation accuracy of decoders trained on RNN latent-space activations from the first time step of each trial to predict different object properties. Each column represents the object property the decoder was trained on, and each row corresponds to a model. (c) Quantification of the validation accuracy (within the same task, in purple) and generalization accuracy (across tasks with different task-relevant features, in yellow) across all model architectures.
  • Figure 3: Orthogonalization: a) A schematic of two hypothetical object spaces in 3D. $r_{i,j}$ represents the angle formed by the decision hyperplanes that separate feature values $i$ and $j$ from each other. Top: non-orthogonalized representation; bottom: orthogonalized representation. b) Upper panel: Normalized orthogonalization index for the perceptual and encoding spaces (denoted $O(Perceptual)$ and $O(Encoding)$, respectively). In most models, feature values are represented less orthogonally in the RNN encoding space than in the perceptual space (CNN output). Lower panel: Statistical comparison of the relative orthogonalization levels of the perceptual and encoding spaces, using a two-sample t-test on the distributions of orthogonalization indices in the two spaces.
  • Figure 4: RNN dynamics during the n-back task: a) Schematic of the 3-back task for a trial of 6 inputs. The model encodes each observed object in its respective encoding space, denoted $E_{(i,j)}$ (diagonal frames with yellow borders). For each stimulus, various object properties are retained over time in their respective memory space, denoted $MS$. On executive steps (frames with red borders), the model produces a response based on the memory of the earlier stimulus and the stimulus newly observed at that step. b) Decoding accuracy for predicting object identity at different time steps, where the decoder is fit to data from the encoding step of an MTMF GRU during the 1/2/3-back identity tasks. The red box indicates the executive steps. c) For each model type, the generalization accuracy on executive (left boxplot) and non-executive (right boxplot) steps. d) Decoding accuracy for decoders trained and tested on the same $E_{i,i}$ space (validation) or tested on other $E_{j,j}, j\neq i$ spaces. e) Schematic of the three hypotheses. f) Left: Schematic of the two latent-space transformations. A structured transformation preserves the topology (i.e., it can be captured solely by a common scaling factor and a rotation matrix); an unstructured transformation does not. Right: Decoding accuracy for fitted decoders (solid line) and reconstructed decoders (dotted line) using the rotation matrix $R_{(i,i)}$ from the Procrustes analysis. The small accuracy gap between fitted and reconstructed decoders suggests a structured transformation. g) Decoding accuracy of the reconstructed decoder when the original rotation matrix is substituted with another (indicated by the x-axis labels). Rows and columns correspond to object properties and MTMF network architectures, respectively.
  • Figure A1: Stimuli and model performance: a) Rendered stimulus examples from ShapeNet. b) The nine N-back task variants, constructed from different choices of task-relevant feature ($L, I, C$) and $N$ index (1, 2, 3). c) Model performance on the training, novel-angle validation, and novel-object validation datasets. Three architectures are tested with various hidden sizes; the number of trainable parameters is indicated on the x-axis.
  • ...and 6 more figures
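The cross-task decoding generalization in Figure 2 can be approximated by fitting a decoder on latent activations recorded during one task and scoring it on every task. The captions do not say which decoder the authors used, so this minimal sketch swaps in a nearest-centroid decoder and uses synthetic activations in place of RNN latent states; the names `fit_centroids`, `decode`, and `generalization_matrix` are hypothetical. Here rows index the task a decoder is fitted on and columns the task it is tested on, and diagonal entries are training accuracy (a held-out split would be needed for true validation accuracy).

```python
import numpy as np

def fit_centroids(X, y):
    """Nearest-centroid decoder (stand-in): one mean activation vector per label."""
    classes = np.unique(y)
    centroids = np.stack([X[y == c].mean(axis=0) for c in classes])
    return classes, centroids

def decode(model, X):
    """Assign each row of X to the label of its closest centroid."""
    classes, centroids = model
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return classes[d2.argmin(axis=1)]

def generalization_matrix(acts, labels):
    """Entry [a, b]: accuracy of a decoder fitted on task a and tested on task b.

    acts[t] is an (n_t, d) array of activations recorded during task t and
    labels[t] holds the value of one object property (e.g. category) per row.
    """
    tasks = list(acts)
    M = np.zeros((len(tasks), len(tasks)))
    for a, ta in enumerate(tasks):
        model = fit_centroids(acts[ta], labels[ta])
        for b, tb in enumerate(tasks):
            M[a, b] = np.mean(decode(model, acts[tb]) == labels[tb])
    return M

# Synthetic demo: three tasks whose activations share the same class geometry,
# so a decoder fitted on any one task should transfer to the others.
rng = np.random.default_rng(0)
n_classes, dim, n = 4, 20, 200
class_means = 3.0 * rng.normal(size=(n_classes, dim))
acts, labels = {}, {}
for task in ("1-back", "2-back", "3-back"):
    y = rng.integers(0, n_classes, n)
    acts[task] = class_means[y] + rng.normal(size=(n, dim))
    labels[task] = y

M = generalization_matrix(acts, labels)
print(np.round(M, 2))
```

With shared geometry every entry is high; off-diagonal entries falling well below the diagonal would instead indicate task-specific subspaces of the kind the abstract reports for gated RNNs.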
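Figure 4f asks whether the map from an encoding space to a later memory space is structured, i.e. a common scale factor times a rotation. A minimal numpy sketch of that test, under stated assumptions: activations are centred; the orthogonal Procrustes solution is computed directly from an SVD (it may include a reflection, so it is orthogonal rather than strictly a rotation); decoders are ridge regressions onto one-hot labels; and a decoder $W$ fitted on the encoding space $A$ is carried into the memory space $B$ as $W_{rec} = R^\top W / s$, which follows from $B \approx sAR$. The direction of the map, the helper names, and the synthetic data are hypothetical, not the authors' pipeline.

```python
import numpy as np

def procrustes_scaled(A, B):
    """Scale s and orthogonal R minimizing ||s * A @ R - B||_F (A, B centred)."""
    U, S, Vt = np.linalg.svd(A.T @ B)
    R = U @ Vt                    # orthogonal Procrustes solution
    s = S.sum() / np.sum(A ** 2)  # least-squares common scale given R
    return s, R

def fit_linear_decoder(X, y, n_classes, ridge=1e-3):
    """Ridge regression onto one-hot labels; returns a (d, n_classes) weight matrix."""
    Y = np.eye(n_classes)[y]
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + ridge * np.eye(d), X.T @ Y)

def accuracy(W, X, y):
    return float(np.mean((X @ W).argmax(axis=1) == y))

# Synthetic "encoding" activations A for one object property ...
rng = np.random.default_rng(1)
n, d, K = 300, 16, 4
y = rng.integers(0, K, n)
means = 3.0 * rng.normal(size=(K, d))
A = means[y] + rng.normal(size=(n, d))
A -= A.mean(axis=0)

# ... and "memory" activations B: a hidden scaled orthogonal map plus small noise.
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))
B = 0.5 * A @ Q + 0.01 * rng.normal(size=(n, d))
B -= B.mean(axis=0)

s, R = procrustes_scaled(A, B)
W_enc = fit_linear_decoder(A, y, K)  # decoder fitted on the encoding space
W_fit = fit_linear_decoder(B, y, K)  # decoder fitted directly on the memory space
W_rec = R.T @ W_enc / s              # encoding decoder carried into memory space

print(f"scale={s:.3f}  fitted={accuracy(W_fit, B, y):.3f}  "
      f"reconstructed={accuracy(W_rec, B, y):.3f}")
```

Because the synthetic memory space really is a scaled orthogonal transform of the encoding space, the fitted and reconstructed decoders score almost identically; an unstructured (e.g. nonlinear) transform would open a gap between them.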