Insights on representational similarity in neural networks with canonical correlation

Ari S. Morcos, Maithra Raghu, Samy Bengio

Abstract

Comparing different neural network representations and determining how representations evolve over time remain challenging open questions in our understanding of the function of neural networks. Comparing representations in neural networks is fundamentally difficult as the structure of representations varies greatly, even across groups of networks trained on identical tasks, and over the course of training. Here, we develop projection weighted CCA (Canonical Correlation Analysis) as a tool for understanding neural networks, building off of SVCCA, a recently proposed method (Raghu et al., 2017). We first improve the core method, showing how to differentiate between signal and noise, and then apply this technique to compare across a group of CNNs, demonstrating that networks which generalize converge to more similar representations than networks which memorize, that wider networks converge to more similar solutions than narrow networks, and that trained networks with identical topology but different learning rates converge to distinct clusters with diverse representations. We also investigate the representational dynamics of RNNs, across both training and sequential timesteps, finding that RNNs converge in a bottom-up pattern over the course of training and that the hidden state is highly variable over the course of a sequence, even when accounting for linear transforms. Together, these results provide new insights into the function of CNNs and RNNs, and demonstrate the utility of using CCA to understand representations.
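For reference, the projection weighting named in the abstract can be sketched as follows. The notation is an assumption, since none of these symbols are defined in this excerpt: $\rho^{(i)}$ are the $c$ canonical correlation coefficients between layers $L_1$ and $L_2$, $h_i$ are the corresponding CCA vectors of $L_1$, and $z_j$ are the activation vectors of the individual neurons of $L_1$ over the dataset.

$$
\tilde{\alpha}_i = \sum_j \left| \langle h_i, z_j \rangle \right|, \qquad
\alpha_i = \frac{\tilde{\alpha}_i}{\sum_{k=1}^{c} \tilde{\alpha}_k}, \qquad
d_{\mathrm{PW}}(L_1, L_2) = 1 - \sum_{i=1}^{c} \alpha_i \, \rho^{(i)}
$$

Setting every $\alpha_i = 1/c$ gives the unweighted mean CCA distance, so the two measures differ only in how much each canonical direction counts. Under this reading, the weighting is computed from $L_1$ alone, so the distance is not necessarily symmetric in its arguments.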

Paper Structure

This paper contains 23 sections, 10 equations, 16 figures.

Figures (16)

  • Figure 1: CCA distinguishes between stable and unstable parts of the representation over the course of training. Sorted CCA coefficients ($\rho^{(i)}_t$) comparing representations between layer $L$ at times $t$ through training with its representation at the final timestep $T$ for CNNs trained on CIFAR-10 (a), and RNNs trained on PTB (b) and WikiText-2 (c). For all of these networks, at time $t_0 < T$ (indicated in title), the performance converges to match final performance (see Figure \ref{fig:model_accs}). However, many $\rho^{(i)}_t$ are unconverged, corresponding to unnecessary parts of the representation (noise). To distinguish between the signal and noise portions of the representation, we apply CCA between $L$ at timestep $t_{early}$ early in training, and $L$ at timestep $T/2$ to get $\rho_{T/2}$. We take the $100$ top converged vectors (according to $\rho_{T/2}$) to form $S$, and the $100$ least converged vectors to form $B$. We then compute CCA similarity between $S$ and $L$ at time $t > t_{early}$, and similarly for $B$. $S$ remains stable through training (signal), while $B$ rapidly becomes uncorrelated (d-f). Note that the sudden spike at $T/2$ in the unstable representation arises because $B$ is chosen to be the set of vectors least correlated with step $T/2$.
  • Figure 2: Projection weighted (PWCCA) vs. SVCCA vs. unweighted mean. Unweighted mean (blue) and projection weighted mean (red) were used to compare synthetic ground truth signal and uncommon (noise) structure, each of fixed dimensionality. As the signal to noise ratio decreases, the unweighted mean underestimates the shared structure, while the projection weighted mean remains largely robust. SVCCA performs better than the unweighted mean but less well than the projection weighting.
  • Figure 3: Generalizing networks converge to more similar solutions than memorizing networks. Groups of 5 networks were trained on CIFAR-10 with either true labels (generalizing) or a fixed random permutation of the labels (memorizing). The pairwise CCA distance was then compared within each group and between generalizing and memorizing networks (inter) for each layer, based on the training data and the projection weighted CCA coefficient (with thresholding to remove low variance noise). While both categories converged to similar solutions in early layers, likely reflecting convergent edge detectors, etc., generalizing networks converge to significantly more similar solutions in later layers. At the softmax, sets of both generalizing and memorizing networks converged to nearly identical solutions, as all networks achieved near-zero training loss. Error bars represent mean $\pm$ std weighted mean CCA distance across pairwise comparisons.
  • Figure 4: Larger networks converge to more similar solutions. Groups of 5 networks with different random initializations were trained on CIFAR-10. Pairwise CCA distance was computed for members of each group. Groups of larger networks converged to more similar solutions than groups of smaller networks (a). Test accuracy was highly correlated with degree of convergent similarity, as measured by CCA distance (b).
  • Figure 5: CCA reveals clusters of converged solutions across networks with different random initializations and learning rates. 200 networks with identical topology and varying learning rates were trained on CIFAR-10. CCA distance between the eighth layer of each pair of networks was computed, revealing five distinct subgroups of networks (a). These five subgroups align almost perfectly with the subgroups discovered in Morcos et al. (2018) (b; colors correspond to bars in a), despite the fact that the clusters in Morcos et al. (2018) were generated using robustness to cumulative ablation, an entirely separate metric.
  • ...and 11 more figures
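
The comparison described in the Figure 2 caption can be illustrated with a small NumPy sketch. This is not the authors' code: the function names (`cca`, `unweighted_mean_cca`, `projection_weighted_cca`, `make_pair`), the QR-plus-SVD route to the canonical correlations, and all synthetic-data parameters are assumptions chosen for illustration. Two synthetic "layers" share a high-variance signal subspace and each carries its own independent low-variance noise.

```python
import numpy as np


def cca(X, Y):
    """Canonical correlations between two activation matrices.

    X: (n_samples, d1), Y: (n_samples, d2), assuming n_samples > max(d1, d2)
    and full column rank. Returns (rho, H, Xc): the correlations in descending
    order, the unit-norm canonical variables of X (columns of H), and the
    mean-centered X.
    """
    Xc = X - X.mean(axis=0)
    Yc = Y - Y.mean(axis=0)
    # Orthonormal bases for the column spaces of the centered data.
    Qx, _ = np.linalg.qr(Xc)
    Qy, _ = np.linalg.qr(Yc)
    # Singular values of Qx^T Qy are the canonical correlations.
    U, s, _ = np.linalg.svd(Qx.T @ Qy, full_matrices=False)
    return np.clip(s, 0.0, 1.0), Qx @ U, Xc


def unweighted_mean_cca(X, Y):
    """Plain average of the canonical correlations."""
    rho, _, _ = cca(X, Y)
    return float(rho.mean())


def projection_weighted_cca(X, Y):
    """Average of the correlations, weighted by how much of X's neuron
    activations each canonical variable accounts for (weights come from X)."""
    rho, H, Xc = cca(X, Y)
    alpha = np.abs(H.T @ Xc).sum(axis=1)  # sum_j |<h_i, z_j>| for each i
    alpha /= alpha.sum()
    return float(alpha @ rho)


def make_pair(rng, n=2000, signal_dim=5, noise_dim=15, signal_scale=10.0):
    """Hypothetical synthetic data: a shared signal plus independent noise,
    mixed by a different random rotation for each 'layer'."""
    Z = rng.standard_normal((n, signal_dim))
    d = signal_dim + noise_dim

    def embed():
        noise = rng.standard_normal((n, noise_dim))
        mix, _ = np.linalg.qr(rng.standard_normal((d, d)))
        return np.hstack([signal_scale * Z, noise]) @ mix

    return embed(), embed()


rng = np.random.default_rng(0)
X, Y = make_pair(rng)
print("unweighted mean CCA:         ", unweighted_mean_cca(X, Y))
print("projection weighted mean CCA:", projection_weighted_cca(X, Y))
```

In this toy setting the shared subspace yields canonical correlations of 1 and the noise directions yield small ones, so the unweighted mean is pulled down roughly in proportion to the number of noise dimensions. The projection weighted mean is pulled down less, because the signal directions account for most of the activation mass. A CCA distance in the sense used by the captions would be one minus either similarity.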