Learned feature representations are biased by complexity, learning order, position, and more

Andrew Kyle Lampinen, Stephanie C. Y. Chan, Katherine Hermann

TL;DR

The paper examines how gradient-based learning biases internal representations toward certain features, even when multiple features contribute equally to computation. By training MLPs, Transformers, and CNNs on controlled, independent features and measuring variance explained via linear regressions, it shows systematic biases toward simpler, earlier-learned, or more prevalent features. It demonstrates causally that these biases influence interpretability tools (e.g., RSA, PCA visualizations) and downstream task performance, complicating cross-system comparisons between models and brains. The work highlights the need to account for learning dynamics and feature properties when inferring computations from representations and suggests directions for more faithful interpretability methods.
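
The TL;DR's core measurement is the share of representation variance explained by a feature, computed via linear regression. It can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code. The function name `variance_explained` is hypothetical. Each representation unit is assumed to be regressed on one scalar feature with an intercept by ordinary least squares, and the residual and total sums of squares are assumed to be pooled across units into a single R². The paper's exact procedure (for example, held-out evaluation or joint regression on several features) may differ.

```python
from typing import Sequence


def variance_explained(reps: Sequence[Sequence[float]],
                       feature: Sequence[float]) -> float:
    """Pooled R^2 of regressing every representation unit on one scalar feature.

    Hypothetical helper (assumption): each unit is fit by ordinary least squares
    with an intercept, and residual / total sums of squares are summed over units.
    """
    n = len(feature)
    f_mean = sum(feature) / n
    f_centered = [f - f_mean for f in feature]
    f_ss = sum(fc * fc for fc in f_centered)
    if f_ss == 0:
        raise ValueError("feature is constant, so R^2 is undefined")

    ss_res = 0.0
    ss_tot = 0.0
    for j in range(len(reps[0])):
        unit = [row[j] for row in reps]
        u_mean = sum(unit) / n
        # OLS slope for this unit: cov(feature, unit) / var(feature)
        slope = sum(fc * (u - u_mean) for fc, u in zip(f_centered, unit)) / f_ss
        for fc, u in zip(f_centered, unit):
            prediction = u_mean + slope * fc  # intercept absorbed by centering
            ss_res += (u - prediction) ** 2
            ss_tot += (u - u_mean) ** 2
    return 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0


# Toy check: all four (easy, hard) combinations; unit 0 encodes the easy
# feature at 3x the scale of unit 1, which encodes the hard feature.
combos = [(e, h) for e in (0, 1) for h in (0, 1)]
reps = [[3.0 * e, float(h)] for e, h in combos]
r2_easy = variance_explained(reps, [e for e, _ in combos])
r2_hard = variance_explained(reps, [h for _, h in combos])
```

In this toy example, both features are perfectly linearly decodable. However, the easy feature's unit has three times the scale, so the easy feature accounts for 90% of the pooled variance and the hard feature for only 10%. Because the two features are independent here, the two R² values sum to 1. With correlated features they would not partition the variance this cleanly.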

Abstract

Representation learning, and interpreting learned representations, are key areas of focus in machine learning and neuroscience. Both fields generally use representations as a means to understand or improve a system's computations. In this work, however, we explore surprising dissociations between representation and computation that may pose challenges for such efforts. We create datasets in which we attempt to match the computational role that different features play, while manipulating other properties of the features or the data. We train various deep learning architectures to compute these multiple abstract features about their inputs. We find that their learned feature representations are systematically biased towards representing some features more strongly than others, depending upon extraneous properties such as feature complexity, the order in which features are learned, and the distribution of features over the inputs. For example, features that are simpler to compute or learned first tend to be represented more strongly and densely than features that are more complex or learned later, even if all features are learned equally well. We also explore how these biases are affected by architectures, optimizers, and training regimes (e.g., in transformers, features decoded earlier in the output sequence also tend to be represented more strongly). Our results help to characterize the inductive biases of gradient-based representation learning. We then illustrate the downstream effects of these biases on various commonly-used methods for analyzing or intervening on representations. These results highlight a key challenge for interpretability, or for comparing the representations of models and brains: disentangling extraneous biases from the computationally important aspects of a system's internal representations.
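
The controlled datasets described in the abstract can be illustrated with the easy/hard boolean task from the Figure 2 caption. There, the easy feature reads out a single input, and the hard feature is the sum of 4 inputs mod 2. The sketch below is hypothetical. The input width of 8 bits, the particular bit indices, the 75/25 split, and the helper names are all assumptions chosen for illustration, not values taken from the paper. Enumerating every input keeps each target an exactly balanced binary label, and knowing one target tells you nothing about the other.

```python
import itertools
import random
from typing import List, Tuple

Example = Tuple[Tuple[int, ...], Tuple[int, int]]


def easy_hard_dataset(n_inputs: int = 8, easy_idx: int = 0,
                      hard_idx: Tuple[int, ...] = (1, 2, 3, 4)) -> List[Example]:
    """Every binary input, labeled with (easy, hard) targets.

    easy: copy of a single input bit.
    hard: parity (sum mod 2) of four other input bits.
    Widths and indices are illustrative assumptions.
    """
    data = []
    for bits in itertools.product((0, 1), repeat=n_inputs):
        easy = bits[easy_idx]
        hard = sum(bits[i] for i in hard_idx) % 2
        data.append((bits, (easy, hard)))
    return data


def train_test_split(data: List[Example], test_frac: float = 0.25,
                     seed: int = 0) -> Tuple[List[Example], List[Example]]:
    """Shuffle with a fixed seed and hold out a fraction of examples for testing."""
    rng = random.Random(seed)
    shuffled = list(data)
    rng.shuffle(shuffled)
    n_test = int(len(shuffled) * test_frac)
    return shuffled[n_test:], shuffled[:n_test]


data = easy_hard_dataset()
train, test = train_test_split(data)
```

Because the easy bit is disjoint from the four parity bits, the two targets are statistically independent. Any later difference in how strongly each target is represented therefore cannot be explained by one target carrying information about the other.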

Paper Structure

This paper contains 47 sections, 3 equations, 43 figures, and 1 table.

Figures (43)

  • Figure 1: An overview of our approach. (top left) We train networks to compute multiple features of an input on controlled datasets where we systematically manipulate various properties of the input and target output distribution, for example, to compute easier- and harder-to-compute features of the input. (top center) We study how these properties bias the feature representations the model learns (for example, towards representing easier-to-compute features more strongly). (top right) We explore the impacts of those biases on downstream areas like interpretability and cognitive neuroscience. (bottom) We consider these phenomena across a range of experiment domains, including boolean functions and various language and vision tasks, using standard architectures for each domain.
  • Figure 2: Test accuracy (left) and representation variance (right) over learning in an MLP trained to output easy and hard features. The easy feature (reading out a single input) is learned quite rapidly, while the hard feature (sum of 4 inputs mod 2) is learned more slowly. Nevertheless, by the end of training the model achieves perfect accuracy for both features on a held-out test set. Despite this, the easy feature occupies substantially more of the penultimate-layer representation variance than the hard feature. Notably, the representation variance attributable to the easy feature continues to grow long after that feature has been mastered, especially while performance on the hard feature is improving most rapidly. (Representation variance is normalized by the total representation variance at the end of training, to show the gradual increase over time; see the Appendix for an unnormalized version. Lines are averages across 10 seeds; intervals are 95% CIs.)
  • Figure 3: The first two principal components of the penultimate representations of an MLP after learning the easy (linear) and hard (sum mod 2) features (see Fig. 2). The top PCs of the representations form two clusters defined entirely by the value of the easy feature (colors); the hard feature (shapes) does not substantially affect the top PCs.
  • Figure 4: Penultimate representation variance explained by the easy (linear) and hard (sum of 4 inputs mod 2) features, across different training paradigms. Pretraining to output the easy feature results in very similar representations to training the features simultaneously, since the easy feature is learned first anyway. Pretraining the hard feature closes the gap somewhat, but the easy feature still dominates. (Bars are averages across 10 seeds, intervals are 95%-CIs.)
  • Figure 5: Penultimate representation variance explained at the end of training by the easy (linear) and hard features, and by all input patterns relevant to the hard feature, across different training paradigms.
  • ...and 38 more figures
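
The PCA observation in Figure 3, where the top PCs cluster by the easy feature alone, can be reproduced in miniature. The sketch below assumes a hypothetical two-unit representation in which both units mix the easy and hard features, but the easy feature has three times the weight. It finds the first principal component by power iteration on the sample covariance, which keeps it dependency-free; this is a swapped-in textbook method, not the paper's tooling. The names `leading_pc` and `project` are illustrative.

```python
import math
import random
from typing import List, Sequence, Tuple


def _center(reps: Sequence[Sequence[float]]) -> List[List[float]]:
    """Subtract the per-unit mean from every representation vector."""
    n, d = len(reps), len(reps[0])
    means = [sum(row[j] for row in reps) / n for j in range(d)]
    return [[row[j] - means[j] for j in range(d)] for row in reps]


def leading_pc(reps: Sequence[Sequence[float]], n_iter: int = 500,
               seed: int = 0) -> Tuple[List[float], float]:
    """First principal component (via power iteration) and its share of total variance."""
    centered = _center(reps)
    n, d = len(centered), len(centered[0])
    # d x d sample covariance of the representation units
    cov = [[sum(r[a] * r[b] for r in centered) / (n - 1) for b in range(d)]
           for a in range(d)]
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    for _ in range(n_iter):
        w = [sum(cov[a][b] * v[b] for b in range(d)) for a in range(d)]
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]
    # Rayleigh quotient gives the eigenvalue; the trace is the total variance
    eigval = sum(v[a] * sum(cov[a][b] * v[b] for b in range(d)) for a in range(d))
    total = sum(cov[a][a] for a in range(d))
    return v, eigval / total


def project(reps: Sequence[Sequence[float]], direction: Sequence[float]) -> List[float]:
    """Score of each centered representation along a given direction."""
    return [sum(x * w for x, w in zip(row, direction)) for row in _center(reps)]


# Hypothetical representation: both units mix the two features,
# but the easy feature carries 3x the weight of the hard one.
combos = [(e, h) for e in (0, 1) for h in (0, 1)]
reps = [[3.0 * e + h, 3.0 * e - h] for e, h in combos]
pc1, share = leading_pc(reps)
scores = project(reps, pc1)
```

Here the first PC captures 90% of the variance, and its scores split the inputs purely by the easy feature. The hard feature only moves points along the second PC, a toy analogue of the clustering shown in Figure 3.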