Self-Attention-Based Contextual Modulation Improves Neural System Identification

Isaac Lin, Tianye Wang, Shang Gao, Shiming Tang, Tai Sing Lee

TL;DR

Self-Attention-Based Contextual Modulation Improves Neural System Identification investigates how self-attention (SA) can augment convolutional neural networks (CNNs) to predict macaque V1 responses to natural images. The study shows that a simple SA layer, paired with a CNN and trained with an incremental learning protocol, improves both the Pearson correlation of predicted tuning curves and the prediction of peak tuning relative to a parameter-matched baseline. By factorizing contextual modulation into convolutions, SA, and a readout, the authors show that information in the local receptive field dominates overall tuning, while surround information is essential for accurately predicting the strongest responses; incremental learning helps separate these contributions. Combining SA with a fully connected readout yields complementary benefits, and learning the receptive field and contextual modulation incrementally appears to be an effective, robust way to capture center-surround interactions. These findings advance understanding of surround modulation in cortical computation and suggest practical training strategies for neural prediction models.
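The tuning-curve correlation is, for each neuron, the Pearson correlation between predicted and recorded responses across the stimulus set. The sketch below assumes responses are stored as a NumPy array of shape (n_stimuli, n_neurons); the function name `tuning_curve_correlation` is hypothetical and is not taken from the authors' code.

```python
import numpy as np


def tuning_curve_correlation(pred: np.ndarray, resp: np.ndarray) -> np.ndarray:
    """Per-neuron Pearson correlation between predicted and recorded responses.

    pred, resp: arrays of shape (n_stimuli, n_neurons); column j is neuron j's
    tuning curve over the whole stimulus set.
    Returns an array of shape (n_neurons,).
    """
    pred = np.asarray(pred, dtype=float)
    resp = np.asarray(resp, dtype=float)
    if pred.shape != resp.shape:
        raise ValueError("pred and resp must have the same shape")
    # Center each neuron's tuning curve across stimuli.
    pc = pred - pred.mean(axis=0)
    rc = resp - resp.mean(axis=0)
    # Covariance divided by the product of standard deviations (the 1/n factors cancel).
    num = (pc * rc).sum(axis=0)
    den = np.sqrt((pc ** 2).sum(axis=0) * (rc ** 2).sum(axis=0))
    return num / den


# Usage with synthetic data (not the paper's recordings): 1,000 stimuli, 5 neurons.
rng = np.random.default_rng(0)
resp = rng.random((1000, 5))
pred = resp + 0.5 * rng.standard_normal((1000, 5))
print(tuning_curve_correlation(pred, resp).mean())
```

Averaging the per-neuron values gives one population score; whether the paper aggregates exactly this way is an assumption. A neuron with a constant response produces 0/0 (NaN) in this sketch.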

Abstract

Convolutional neural networks (CNNs) have been shown to be state-of-the-art models for visual cortical neurons. Cortical neurons in the primary visual cortex are sensitive to contextual information mediated by extensive horizontal and feedback connections. Standard CNNs integrate global contextual information to model contextual modulation via two mechanisms: successive convolutions and a fully connected readout layer. In this paper, we find that self-attention (SA), an implementation of non-local network mechanisms, can improve neural response predictions over parameter-matched CNNs in two key metrics: tuning curve correlation and peak tuning. We introduce peak tuning as a metric to evaluate a model's ability to capture a neuron's top feature preference. We factorize networks to assess each context mechanism, revealing that information in the local receptive field is most important for modeling overall tuning, but surround information is critically necessary for characterizing the tuning peak. We find that self-attention can replace posterior spatial-integration convolutions when learned incrementally, and is further enhanced in the presence of a fully connected readout layer, suggesting that the two context mechanisms are complementary. Finally, we find that decomposing receptive field learning and contextual modulation learning in an incremental manner may be an effective and robust mechanism for learning surround-center interactions.
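Self-attention acts as a non-local mechanism because every spatial location of a feature map can draw on every other location in a single step, whereas each convolution mixes only a local neighborhood. The NumPy sketch below is a minimal single-head version under stated assumptions: the scaled dot-product form, the weight shapes, the absence of a residual connection, and the hypothetical readout helpers `center_readout` (CTL-style) and `fully_connected_readout` (FCL-style) are illustrative choices, not the authors' implementation. The `gamma` flag mirrors the Figure 2 caption, in which V is either mapped or left equal to the input.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max before exponentiating for numerical stability.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def spatial_self_attention(fmap, w_q, w_k, w_v=None, gamma=True):
    """Single-head self-attention over the spatial positions of a feature map.

    fmap: (H, W, C) output of the convolutional blocks.
    w_q, w_k: (C, D) query/key projections (hypothetical learned weights).
    w_v: (C, C) value projection, used only when gamma is True.
    gamma: mirrors the Figure 2 flag; if False, V is the input itself.
    """
    h, w, c = fmap.shape
    x = fmap.reshape(h * w, c)              # one token per spatial location
    q, k = x @ w_q, x @ w_k
    if gamma:
        if w_v is None:
            raise ValueError("gamma=True requires a value projection w_v")
        v = x @ w_v
    else:
        v = x
    # (HW, HW) scores: each location is compared with every other location,
    # which is what makes the operation non-local.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    attn = softmax(scores, axis=-1)         # each row sums to 1
    return (attn @ v).reshape(h, w, c)


def center_readout(fmap, w_out):
    """CTL-style readout: linear map from the center hypercolumn only."""
    h, w, _ = fmap.shape
    return fmap[h // 2, w // 2] @ w_out     # (C,) @ (C, N) -> (N,)


def fully_connected_readout(fmap, w_out):
    """FCL-style readout: linear map from the entire flattened feature map."""
    return fmap.reshape(-1) @ w_out         # (H*W*C,) @ (H*W*C, N) -> (N,)


# Usage with random weights: a 5x5x8 feature map read out to 3 neurons.
rng = np.random.default_rng(0)
fmap = rng.standard_normal((5, 5, 8))
w_q, w_k = rng.standard_normal((8, 4)), rng.standard_normal((8, 4))
ctx = spatial_self_attention(fmap, w_q, w_k, rng.standard_normal((8, 8)), gamma=True)
print(center_readout(ctx, rng.standard_normal((8, 3))).shape)
print(fully_connected_readout(fmap, rng.standard_normal((5 * 5 * 8, 3))).shape)
```

With a CTL-style readout, surround information can reach the prediction only through what the convolutions and SA have already mixed into the center hypercolumn, while an FCL-style readout sees every location directly.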

Paper Structure (39 sections, 1 equation, 16 figures, 3 tables)

Figures (16)

  • Figure 1: Macaque neuronal response dataset. (a) shows a two-photon image with cells. (b) shows a feedforward CNN used to model neural responses. (c) shows the responses of one neuron to 50k stimuli and the 20 images that evoked the strongest responses. On average, fewer than 0.5% of the images evoke responses greater than half the peak height. Each site contains around 300 neurons.
  • Figure 2: Models explored in this study. Models are built from two types of convolutional processing blocks (CPBs): $\alpha$CPB and $\beta$CPB. $\alpha$CPB has a fixed convolution kernel size of 5 and a max-pooling kernel size of 2. $\beta$CPB takes an input convolution kernel size $k$ and has no pooling layers. The two final-layer readout modes are fully connected (FCL) and center hypercolumn only (CTL). Self-attention (SA) takes a Boolean flag $\gamma$ that determines whether the value (V) vector is transformed: if $\gamma$ is True, V is mapped; otherwise, V equals the input. All models with SA use single-headed attention. (a) shows the feedforward CNN. (b) shows the feedforward CNN augmented with self-attention. (c) shows the receptive field CNN. (d) shows the receptive field CNN augmented with self-attention.
  • Figure 3: Incremental learning models. (a) shows the baseline receptive field CNN, equivalent to Figure 2(c). (b) shows (a) augmented with SA and learned incrementally: the $\alpha$CPBs are taken from (a), and the remaining layers are learned. The $^*$ denotes a slight modification of rf+sa-CNN (Figure 2(d)), namely that $\gamma$ is changed to True. (c) and (d) show the result of replacing the CTL in (b) with an FCL and learning incrementally; (c) freezes only the center hypercolumn of the FCL (FC$_1$), whereas (d) lets the FCL learn freely (FC$_2$). (c) and (d) take all other layers from (b). Here the $^*$ denotes a slight modification of ff+sa-CNN (Figure 2(b)), namely that $k$ in $\beta$CPB is changed to $k=1$. Simultaneously trained (Simul.) models have the same architectures, except that all blocks are learned together.
  • Figure 4: Neuronal tuning curves of ff-CNN, ff+sa-CNN, and rf-CNN. Pearson correlation does not reflect peak tuning: although rf-CNN has the higher correlation, ff+sa-CNN clearly captures the peak better, at the cost of a noisier overall tuning curve. The example shown is M1S1 neuron 238. See the appendix for population averages.
  • Figure 5: Average peak tuning indices for incrementally and simultaneously trained models. Top row: bar charts for M1S1. Bottom row: bar charts for M2S1. Left column: average PT$_J$ values. Right column: average PT$_s$ values. Error bars show SEM.
  • ...and 11 more figures
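The response-sparsity statistics in the Figure 1 caption (the top 20 images per neuron, and the fraction of images evoking more than half the peak response) can be computed directly from a response matrix. This sketch assumes responses are a NumPy array of shape (n_stimuli, n_neurons) and reads "half peak height" as half of each neuron's maximum response with no baseline subtraction; both readings, and the helper names, are hypothetical.

```python
import numpy as np


def fraction_above_half_peak(responses: np.ndarray) -> np.ndarray:
    """Per-neuron fraction of stimuli whose response exceeds half of the peak.

    responses: (n_stimuli, n_neurons). Assumes each neuron's peak is positive
    and uses 0.5 * max as the threshold (no baseline subtraction).
    """
    r = np.asarray(responses, dtype=float)
    peak = r.max(axis=0)                    # (n_neurons,)
    return (r > 0.5 * peak).mean(axis=0)    # mean of a boolean mask = fraction


def top_k_stimuli(neuron_responses: np.ndarray, k: int = 20) -> np.ndarray:
    """Indices of the k stimuli that evoke the strongest responses, strongest first."""
    return np.argsort(neuron_responses)[::-1][:k]


# Usage with synthetic heavy-tailed responses (not the paper's data):
# 50k stimuli for 10 neurons.
rng = np.random.default_rng(0)
resp = rng.exponential(size=(50_000, 10))
print("average fraction above half peak:", fraction_above_half_peak(resp).mean())
print("top 20 stimuli for neuron 0:", top_k_stimuli(resp[:, 0]))
```

Averaging the per-neuron fractions over a site gives a population figure comparable in spirit to the caption's "on average" statement.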