
Predictive auxiliary objectives in deep RL mimic learning in the brain

Ching Fang, Kimberly L Stachenfeld

TL;DR

This work demonstrates how representation learning in deep RL systems can provide an interpretable framework for modeling multi-region interactions in the brain, and draws a connection between the auxiliary predictive model and the hippocampus, an area thought to learn a predictive model to support memory-guided behavior.

Abstract

The ability to predict upcoming events has been hypothesized to be a key aspect of natural and machine cognition. This idea is supported by trends in deep reinforcement learning (RL), where self-supervised auxiliary objectives such as prediction are widely used to support representation learning and improve task performance. Here, we study the effects predictive auxiliary objectives have on representation learning across different modules of an RL system and how these effects mimic representational changes observed in the brain. We find that predictive objectives improve and stabilize learning, particularly in resource-limited architectures, and we identify settings where longer predictive horizons better support representational transfer. Furthermore, we find that representational changes in this RL system bear a striking resemblance to changes in neural activity observed in the brain across various experiments. Specifically, we draw a connection between the auxiliary predictive model of the RL system and the hippocampus, an area thought to learn a predictive model to support memory-guided behavior. We also connect the encoder network and the value learning network of the RL system to visual cortex and striatum in the brain, respectively. This work demonstrates how representation learning in deep RL systems can provide an interpretable framework for modeling multi-region interactions in the brain. The deep RL perspective taken here also suggests an additional role for the hippocampus in the brain: that of an auxiliary learning system that benefits representation learning in other regions.
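The idea of a predictive auxiliary objective with an adjustable horizon can be illustrated with a small sketch. This section does not give the exact form of $\mathcal{L}_{pred}$ or $\mathcal{L}_-$, so what follows is an assumption, not the authors' implementation: we assume a successor-style bootstrapped target in which the predictor $T$ maps $z_t$ toward $z_{t+1} + \gamma T(z_{t+1})$ (so $\gamma = 0$ is one-step prediction and larger $\gamma$ lengthens the horizon), and a negative-sampling term $\exp(-C \lVert z_a - z_b \rVert)$ that pushes randomly paired latents apart. To stay dependency-free, $T$ is a lookup table over one-hot latents on a ring where the agent always steps clockwise, standing in for an MLP on encoder outputs. All function names (`predictive_target`, `predictive_loss`, `negative_sampling_loss`, `train_tabular_predictor`) and the constant `c` are hypothetical.

```python
import math


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def one_hot(i, n):
    v = [0.0] * n
    v[i] = 1.0
    return v


def predictive_target(T, z_next, gamma):
    # Assumed successor-style target: the next latent plus a discounted
    # prediction made from it. gamma = 0 reduces to next-latent prediction.
    return [zn + gamma * tn for zn, tn in zip(z_next, T(z_next))]


def predictive_loss(T, z_t, z_next, gamma):
    # In a gradient-based setup the target would be treated as a constant.
    return sq_dist(T(z_t), predictive_target(T, z_next, gamma))


def negative_sampling_loss(z_a, z_b, c=5.0):
    # Assumed form: equals 1 when two random latents collapse together
    # and decays toward 0 as they move apart.
    return math.exp(-c * math.sqrt(sq_dist(z_a, z_b)))


def train_tabular_predictor(n_states, gamma, lr=0.5, sweeps=200):
    """Fit a tabular T on a ring where every transition is s -> s + 1."""
    M = [[0.0] * n_states for _ in range(n_states)]

    def T(z):
        # Lookup for one-hot latents (stands in for the predictive MLP).
        return M[z.index(max(z))]

    for _ in range(sweeps):
        for s in range(n_states):
            z_next = one_hot((s + 1) % n_states, n_states)
            target = predictive_target(T, z_next, gamma)
            # Move the prediction for state s a fraction lr toward the target.
            M[s] = [m + lr * (y - m) for m, y in zip(M[s], target)]
    return M, T
```

With 8 states, `train_tabular_predictor(8, 0.0)` converges to `M[0]` equal to `one_hot(1, 8)` (pure next-state prediction), while `gamma=0.5` spreads the prediction over upcoming states with weight $\gamma^{k}/(1-\gamma^{8})$ for a state $k+1$ steps ahead. This is one concrete sense in which a larger $\gamma$ yields a longer predictive horizon.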

Paper Structure

This paper contains 16 sections, 11 figures, and 3 tables.

Figures (11)

  • Figure 1: A deep RL framework to model multi-region computation. A. In the deep RL model we use, reward is provided as a scalar input $r$. Observations $o$ are 2D visual inputs fed into an encoder (green) that learns low-dimensional state space representations $z$. The encoder is a convolutional neural network. Representations $z$ are used to learn Q values via an MLP (blue); these Q values are used to select actions $a$. A predictive auxiliary objective (orange) is enforced by a separate MLP that learns predictions from $z$.
  • Figure 2: Gridworld performance with predictive auxiliary tasks. A. The model is tested on a gridworld task in an 8x8 arena. The agent must navigate to a hidden reward from a random initial starting location. B. Average episode score across training steps for models without auxiliary losses (blue), with only the negative sampling loss $\mathcal{L}_-$ (green), and with the full predictive loss $\mathcal{L}_{pred}$ (orange). The maximum score is 1 and $|z|=10$ (i.e. $z$ contains 10 units). In each step, the network is trained on one batch of replayed transitions (batch size 64). All error bars are the standard error of the mean over 45 random seeds. C. 3D PCA representations of latent states $z$ for the models in (B) (two random seeds). The latent states are colored by the quadrant of the arena they lie in; the quadrants (in order) are purple, pink, gray, and brown. The goal state is colored red. Gray lines represent the true connectivity between states. D. Diagram of the encoder network (red), learned latent state (gray), and value-learning network (blue). We vary $|z|$ (see E, F), as well as the encoder/decoder depths (Appendix figure, panels A and B). E. Average episode score at the end of learning (600 training steps) across $|z|$. F. Fraction of units in $z$ that are silent during the task, across $|z|$. G. Cosine similarity of two randomly sampled states throughout learning, $|z|=10$.
  • Figure 3: Effects of predictive auxiliary objectives across transfer learning scenarios. A. We test goal transfer by moving the goal location to a new state in task B. After training on task A, encoder weights are frozen and the value function is fine-tuned on task B. B. Average episode score across task A, then task B. All models shown use the predictive auxiliary loss, with the shade of each line corresponding to the magnitude of $\gamma$ in $\mathcal{L}_{pred}$ ($\gamma \in \{0.0, 0.25, 0.5, 0.8\}$, $|z|=17$). C. Episode score after $100$ training steps for each of the models in (B), as $|z|$ is increased. All models achieve maximum performance in task A. $30$ random seeds are run for each latent size. D. 3D PCA plots for three models ($\gamma \in \{0.0, 0.25, 0.5\}$) with the same random seed. E. Pairwise cosine similarity values between the corner states of the arena for the model shown in (B). F. We test transition transfer by shuffling the connectivity between all states in task B. Freezing and fine-tuning are the same as in (A). G. Average episode score across task A, then task B. Here, $|z|=17$ and an $\epsilon$-greedy policy with $\epsilon=0.4$ is used during learning. In green is the model with only $\mathcal{L}_-$ as an auxiliary loss. H. Episode score after $150$ training steps for the model with only $\mathcal{L}_-$ (green) versus the model with $\mathcal{L}_{pred}$ for $\gamma=0.8$. The x-axis varies the $\epsilon$ of the policy used during training, with $\epsilon=1.0$ corresponding to a fully random policy ($|z|=17$; all models achieve maximum performance on task A).
  • Figure 4: Representational changes in the predictive model are similar to those observed in the hippocampus. A. 2D foraging experiments are simulated as in the gridworld task from Figs. 1 and 2. B. 2D receptive fields of the top four $T$ units (columns), sorted by spatial information score (Skaggs et al., 1992). Three random seeds are shown (rows). The model uses $\mathcal{L}_{pred}$ and $|z|=10$. The white asterisk marks the reward. C. As in (B), but the model has no auxiliary objectives. D. Circular track experiments are simulated in a circular gridworld with $28$ states. Reward is placed in a random state for each seed, and the agent is rewarded for running clockwise to the reward. E. Receptive fields of two example units in the $T$ network before (gray) and after (orange) learning. F. Histogram of the shift in receptive field peaks for units in $T$ over $15$ random seeds, where $|z|=24$. Positive values indicate forward shifts and negative values indicate backward shifts. Black dotted line at $0$. The median of the histogram is $-0.034$. G. Histogram of the locations of receptive field peaks for units in (F), centered on the reward site. The random shuffle control (gray) was made by randomly shuffling the weights of the $T$ network. Black dotted line at $0$. The model median is $-0.06$, while the random shuffle median is $-0.02$. H. We simulate a 5x5 alternating-T maze (see Appendix); the center corridor is shown in pink. I. Cosine similarity of $T$ population vector responses in the center corridor under left-turn versus right-turn conditions. The x-axis depicts location in the center corridor. Data are from $20$ random seeds. Shown are the model without auxiliary objectives (blue) and the model with $\mathcal{L}_{pred}$ (orange). $T$ is randomly initialized for the model without an auxiliary objective.
  • Figure 5: Representational changes in the encoder model resemble recordings from visual cortex. A. Example sequence structure in the preference swap task of Li & DiCarlo (2008, 2010); images are numbered by sequence location. B. Example changes in IT neuron responses to preferred images (red) and non-preferred images (blue) across exposure to new image transitions. C. Responses of two example units from the model with $\mathcal{L}_{pred}$. Arrows indicate the response profile before and after experiencing swapped transitions. Red indicates the response to states $P1, P2, P3$ selected from the gridworld environment, while blue indicates the response to states $N1, N2, N3$ selected from the environment. D. Change in response difference between $(P1, N1)$, $(P2, N2)$, and $(P3, N3)$ over $10$ units. Each unit is a separate transition swap experiment. Shown are the model without any auxiliary objectives (blue) and the model with $\mathcal{L}_{pred}$ (orange). Asterisks indicate significance from a t-test comparing the means of the two models. The means of both models are also significantly different from 0. E. Linear track VR experiment used in Poort et al. (2015). Vertical-stripe corridors were rewarded but angled corridors were not. Animals experienced either condition at random following an approach corridor. F. Selectivity across the population before learning (gray) and after learning (orange). Selectivity was calculated as in Poort et al. (2015), with negative and positive values corresponding to angled and vertical corridor preference, respectively. Asterisks indicate significance from a one-tailed t-test ($t=-12.43$, $p=9\times10^{-36}$). G. Selectivity of individual units before and after learning for the vertical condition (V), angled condition (A), or neither (N/A). Units are pooled across $15$ experiments.
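Several of the captions above rely on simple representational metrics: cosine similarity between latent vectors (Figs. 2G, 3E, 4I), the fraction of silent units in $z$ (Fig. 2F), and the spatial information score of Skaggs et al. (1992) used to rank $T$ units (Fig. 4B). The sketch below shows one way to compute them in plain Python. It rests on several assumptions. Activations are taken to be nonnegative (e.g. post-ReLU) so they can stand in for firing rates. A unit counts as silent if its absolute activity stays at or below a hypothetical tolerance `tol` in every state. Spatial information uses the standard form $I = \sum_i p_i (\lambda_i/\bar\lambda)\log_2(\lambda_i/\bar\lambda)$ with $\bar\lambda = \sum_i p_i \lambda_i$. The paper's exact preprocessing may differ, and all function names are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def fraction_silent(activations, tol=1e-6):
    """activations[state][unit]; a unit is silent if |activity| <= tol everywhere."""
    n_units = len(activations[0])
    silent = sum(
        1 for u in range(n_units)
        if all(abs(row[u]) <= tol for row in activations)
    )
    return silent / n_units


def spatial_information(rates, occupancy):
    """Skaggs-style spatial information (bits per unit of activity).

    rates[i]: mean (nonnegative) activity in spatial bin i.
    occupancy[i]: time or visit count spent in bin i.
    """
    total = sum(occupancy)
    p = [o / total for o in occupancy]
    mean_rate = sum(pi * ri for pi, ri in zip(p, rates))
    if mean_rate <= 0.0:
        return 0.0
    info = 0.0
    for pi, ri in zip(p, rates):
        if ri > 0.0:  # zero-rate bins contribute 0 (limit of x log x)
            ratio = ri / mean_rate
            info += pi * ratio * math.log2(ratio)
    return info


def top_units_by_spatial_info(rate_maps, occupancy, k=4):
    """rate_maps[unit] is a list of per-bin rates; returns the k most informative units."""
    scores = [spatial_information(m, occupancy) for m in rate_maps]
    return sorted(range(len(scores)), key=lambda u: scores[u], reverse=True)[:k]
```

For example, `top_units_by_spatial_info(rate_maps, occupancy, k=4)` mirrors the "top four units" selection in Fig. 4B, assuming `rate_maps` holds each $T$ unit's activity averaged per arena location. A unit active in only one of four equally visited bins scores 2 bits, while a spatially uniform unit scores 0.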