
Bridging Critical Gaps in Convergent Learning: How Representational Alignment Evolves Across Layers, Training, and Distribution Shifts

Chaitanya Kapoor, Sudhanshu Srivastava, Meenakshi Khosla

TL;DR

This work presents a large-scale audit of representational convergence across dozens of vision models (including self-supervised networks and vision transformers) as well as language models. It compares three alignment families (linear regression, orthogonal Procrustes, and permutation/soft-matching) to quantify how similarly independently trained networks encode information, and tracks these alignments across layers, over training, and under distribution shifts. Key findings include strong early-layer convergence across architectures and metrics, only modest gains from mappings more flexible than rotation/reflection, and deeper-layer representations that diverge under out-of-distribution inputs, with parallel patterns in CNNs, ViTs, MoCo, and Pythia-based language models. The results paint a robust, depth-dependent, and largely input-statistics-driven portrait of convergent learning, with implications for brain modeling in neuroscience, model evaluation under distribution shifts, and architectural design choices that foster or limit representational alignment.
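The three alignment families can be sketched in a few lines of NumPy/SciPy. This is a minimal illustration under assumptions of our own, not the authors' implementation: `X` and `Y` are hypothetical stimulus-by-unit activation matrices from two networks responding to the same stimuli, the Procrustes and permutation variants assume both have the same number of units, every score is the mean Pearson correlation between predicted and actual units, and scores are computed in-sample (the figures below report cross-validated scores). All function names are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def _center(a):
    """Subtract each unit's mean response across stimuli."""
    return a - a.mean(axis=0, keepdims=True)


def _mean_unit_corr(pred, target):
    """Average Pearson correlation between matching columns (units)."""
    pred, target = _center(pred), _center(target)
    num = (pred * target).sum(axis=0)
    den = np.linalg.norm(pred, axis=0) * np.linalg.norm(target, axis=0)
    return float(np.mean(num / den))


def linear_score(X, Y):
    """Linear family: unconstrained least-squares map from X to Y."""
    X, Y = _center(X), _center(Y)
    W, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return _mean_unit_corr(X @ W, Y)


def procrustes_score(X, Y):
    """Rotation/reflection family: best orthogonal map from X to Y."""
    X, Y = _center(X), _center(Y)
    U, _, Vt = np.linalg.svd(X.T @ Y)
    Q = U @ Vt  # orthogonal Q minimizing ||X Q - Y||_F
    return _mean_unit_corr(X @ Q, Y)


def permutation_score(X, Y):
    """Unit-order family: optimal one-to-one matching of units."""
    X, Y = _center(X), _center(Y)
    Xn = X / np.linalg.norm(X, axis=0)
    Yn = Y / np.linalg.norm(Y, axis=0)
    corr = Xn.T @ Yn  # unit-by-unit Pearson correlations
    rows, cols = linear_sum_assignment(corr, maximize=True)
    return float(corr[rows, cols].mean())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.standard_normal((500, 32))
    Q, _ = np.linalg.qr(rng.standard_normal((32, 32)))
    Y = X @ Q + 0.1 * rng.standard_normal((500, 32))  # a noisy rotation of X
    print(linear_score(X, Y), procrustes_score(X, Y), permutation_score(X, Y))
```

Each family is nested in the previous one (a permutation matrix is orthogonal, and an orthogonal map is linear). With this per-unit correlation score, only the linear score is guaranteed to be the highest in-sample, because least squares maximizes each unit's correlation over all linear predictors; Procrustes optimizes summed covariance, so it need not beat permutation unit by unit.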

Abstract

Understanding convergent learning, the degree to which independently trained neural systems (whether multiple artificial networks, or brains and models) arrive at similar internal representations, is crucial for both neuroscience and AI. Yet the literature remains narrow in scope: studies typically examine just a handful of models on one dataset, rely on a single alignment metric, and evaluate networks at a single post-training checkpoint. We present a large-scale audit of convergent learning, spanning dozens of vision models and thousands of layer-pair comparisons, to close these long-standing gaps. First, we pit three alignment families against one another: linear regression (affine-invariant), orthogonal Procrustes (rotation-/reflection-invariant), and permutation/soft-matching (unit-order-invariant). We find that orthogonal transformations align representations nearly as effectively as more flexible linear ones, and although permutation scores are lower, they significantly exceed chance, indicating a privileged representational basis. Tracking convergence throughout training further shows that nearly all eventual alignment crystallizes within the first epoch, well before accuracy plateaus, indicating that it is largely driven by shared input statistics and architectural biases rather than by the final task solution. Finally, when models are challenged with a battery of out-of-distribution images, early layers remain tightly aligned, whereas deeper layers diverge in proportion to the distribution shift. These findings fill critical gaps in our understanding of representational convergence, with implications for neuroscience and AI.
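The abstract does not say how chance is defined for permutation scores. One plausible baseline, assumed here purely for illustration, applies random orthogonal rotations to one representation: a rotation keeps the representational geometry intact (an orthogonal map can still undo it) but scrambles the individual unit axes, so a permutation score well above this null suggests that the units themselves form a privileged basis. The helper names, the rotation null, and the empirical p-value are our own choices, not the paper's.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def permutation_score(X, Y):
    """Mean Pearson correlation of units after optimal one-to-one matching."""
    Xc = X - X.mean(axis=0)
    Yc = Y - Y.mean(axis=0)
    Xn = Xc / np.linalg.norm(Xc, axis=0)
    Yn = Yc / np.linalg.norm(Yc, axis=0)
    corr = Xn.T @ Yn  # unit-by-unit correlation matrix
    rows, cols = linear_sum_assignment(corr, maximize=True)
    return float(corr[rows, cols].mean())


def random_rotation(n, rng):
    """Haar-random orthogonal matrix: QR of a Gaussian matrix, sign-corrected."""
    Q, R = np.linalg.qr(rng.standard_normal((n, n)))
    return Q * np.sign(np.diag(R))


def rotation_null_test(X, Y, n_null=200, seed=0):
    """Compare the observed permutation score with scores after rotating Y."""
    rng = np.random.default_rng(seed)
    observed = permutation_score(X, Y)
    null = np.array([
        permutation_score(X, Y @ random_rotation(Y.shape[1], rng))
        for _ in range(n_null)
    ])
    # One-sided empirical p-value with the usual +1 correction.
    p_value = (1 + np.sum(null >= observed)) / (1 + n_null)
    return observed, null, float(p_value)
```

With `n_null` rotations, the smallest attainable p-value is `1 / (n_null + 1)`, so the number of null samples bounds how strongly "above chance" can be claimed.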

Paper Structure

This paper contains 34 sections, 1 equation, 16 figures, 2 tables.

Figures (16)

  • Figure 1: Representational Convergence Across a Network Hierarchy. We plot alignment scores (computed between different seeds of the same network architecture) across the network hierarchy for four vision network architectures trained on ImageNet. A consistent downward trend across layers indicates decreasing representational convergence with depth. Alignment consistently follows the order Linear $>$ Procrustes $>$ Permutation, reflecting the progressively stricter nature of the metrics. Lighter shades of the same color denote alignment for random networks. Notably, Procrustes transformations align representations nearly as well as linear transformations, suggesting that most of the variability across seeds is due to rotations and reflections rather than more complex transformations. Even Permutation scores, despite their strictness, achieve substantial alignment, indicating a strong one-to-one correspondence between neurons across seeds and pointing to stable, convergent neuron-level representations. Error bars represent the standard deviation across $5$-fold cross-validation.
  • Figure 2: Inter-Model Comparisons. We consider all pairs of vision models trained on ImageNet and, for each pair, compute alignment scores between every pair of layers using the orthogonal Procrustes (Top) and Soft-Matching (Bottom) metrics. Gray line plots denote the maximum alignment value for each network over rows (right line) and columns (top line). A common trend is the consistent relationship between layers of CNNs with different architectures.
  • Figure 3: Representational Alignment Through Training Evolution. We visualize the evolution of Procrustes alignment between network pairs during task optimization on ImageNet. Lighter shades indicate earlier epochs, progressively darkening with later epochs. The plots span from epoch $0$ (untrained) to epoch $10$, with task performance improving over time. Epoch progression can be inferred from the increasing task performance along the $x$-axis.
  • Figure 4: Procrustes Score-based Alignment Between Networks Sharing the Same Architecture but Trained With Different Random Seeds, Plotted as a Function of Layer Depth. Alignment is measured using within-distribution (WD) stimuli (ImageNet test set) and out-of-distribution (OOD) stimuli, with OOD values averaged across $17$ datasets. Error bars represent the standard error across the $n = 17$ OOD datasets.
  • Figure 5: Convergence on OOD Inputs. (a) Procrustes alignment vs. task performance of ResNet-50 models on each of the $17$ datasets for the first convolutional layer (Left) and the penultimate layer (Right). (b) Correlation between these variables as a function of network depth (normalized by each model's depth).
  • ...and 11 more figures
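The error bars in Figure 1 come from $5$-fold cross-validation. A minimal sketch of a cross-validated, layer-by-layer Procrustes profile follows, under assumptions of our own: folds split the stimuli, the orthogonal map is fit on the training folds and scored on the held-out fold by mean per-unit Pearson correlation, and depth is normalized as (layer index + 1) / number of layers, loosely mirroring the normalized depth in Figure 5(b). All names are hypothetical.

```python
import numpy as np


def procrustes_fit(X, Y):
    """Orthogonal Q minimizing ||X Q - Y||_F (inputs assumed centered)."""
    U, _, Vt = np.linalg.svd(X.T @ Y)
    return U @ Vt


def mean_unit_corr(pred, target):
    """Average Pearson correlation between matching units."""
    pred = pred - pred.mean(axis=0)
    target = target - target.mean(axis=0)
    num = (pred * target).sum(axis=0)
    den = np.linalg.norm(pred, axis=0) * np.linalg.norm(target, axis=0)
    return float(np.mean(num / den))


def cv_procrustes(X, Y, k=5, seed=0):
    """Mean and std of held-out Procrustes scores over k stimulus folds."""
    idx = np.random.default_rng(seed).permutation(X.shape[0])
    folds = np.array_split(idx, k)
    scores = []
    for i, test in enumerate(folds):
        train = np.concatenate([f for j, f in enumerate(folds) if j != i])
        # Center with training-fold means only, to avoid test-set leakage.
        mu_x, mu_y = X[train].mean(axis=0), Y[train].mean(axis=0)
        Q = procrustes_fit(X[train] - mu_x, Y[train] - mu_y)
        scores.append(mean_unit_corr((X[test] - mu_x) @ Q, Y[test] - mu_y))
    return float(np.mean(scores)), float(np.std(scores))


def layerwise_profile(acts_a, acts_b, k=5):
    """acts_a, acts_b: per-layer (stimuli x units) arrays from two seeds."""
    n_layers = len(acts_a)
    profile = []
    for depth, (X, Y) in enumerate(zip(acts_a, acts_b)):
        mean, std = cv_procrustes(X, Y, k=k)
        profile.append(((depth + 1) / n_layers, mean, std))
    return profile
```

Each entry of the returned list is `(normalized depth, mean score, std across folds)`, one point with its error bar per layer; replacing `procrustes_fit` with an unconstrained least-squares fit would give a linear-regression curve under the same protocol.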