
Metric Learning Encoding Models: A Multivariate Framework for Interpreting Neural Representations

Louis Jalouzot, Christophe Pallier, Emmanuel Chemla, Yair Lakretz

TL;DR

Metric Learning Encoding Models (MLEMs) recast the interpretation of neural representations as a metric-learning problem over a space of theoretical features. MLEM learns a symmetric positive definite (SPD) matrix $W$ that defines a weighted norm over vectors of feature distances. This lets it capture both individual feature effects and their interactions. It is optimized for rank-based (Spearman) alignment with neural distances. Across simulations and LLM-derived data, MLEM shows better weight recovery, greater robustness to noise, and faster convergence than FR-RSA-I. It also reveals interpretable, layer-wise geometric structure in representations. The framework is modality-agnostic, scales through batch-based training, and is supported by open-source software for use in both neuroscience and AI.
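
The sketch below makes the role of $W$ concrete. It computes a weighted norm of a feature-distance vector $d$, assuming the standard quadratic form $\|d\|_W = \sqrt{d^\top W d}$. It also shows one common way to keep $W$ SPD during optimization: build it as $W = LL^\top + \epsilon I$ from an unconstrained factor $L$. That construction is an assumption here and not necessarily the parametrization used in the released code. The helper names are hypothetical.

```python
import math


def spd_from_factor(L, eps=1e-6):
    """Return W = L L^T + eps * I (hypothetical helper).

    L L^T is symmetric positive semi-definite for any real matrix L; adding
    eps * I makes it strictly positive definite, so ||.||_W is a valid norm.
    """
    m, r = len(L), len(L[0])
    W = [[sum(L[i][k] * L[j][k] for k in range(r)) for j in range(m)] for i in range(m)]
    for i in range(m):
        W[i][i] += eps
    return W


def weighted_norm(d, W):
    """||d||_W = sqrt(d^T W d) for a feature-distance vector d (hypothetical helper)."""
    m = len(d)
    return math.sqrt(sum(d[a] * W[a][b] * d[b] for a in range(m) for b in range(m)))


# Two binary features (e.g., number and tense); this pair differs on both.
d = [1.0, 1.0]
no_interaction = [[1.0, 0.0], [0.0, 1.0]]
with_interaction = [[1.0, 0.5], [0.5, 1.0]]  # positive off-diagonal term, still SPD
print(weighted_norm(d, no_interaction))    # sqrt(2)
print(weighted_norm(d, with_interaction))  # sqrt(3): the joint difference counts extra
```

With a diagonal $W$, each feature contributes independently. A nonzero off-diagonal entry changes the distance only for pairs that differ on both features at once, which is how the off-diagonal coefficients encode interactions.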

Abstract

Understanding how explicit theoretical features are encoded in opaque neural systems is a central challenge now common to neuroscience and AI. We introduce Metric Learning Encoding Models (MLEMs), which frame this challenge directly as a metric learning problem: we fit the distance in the space of theoretical features to match the distance in neural space. Our framework improves on univariate encoding and decoding methods by building on second-order isomorphism methods, such as Representational Similarity Analysis (RSA). It extends them by learning a metric that efficiently models individual features as well as the interactions between them. The effectiveness of MLEM is validated through two sets of simulations. First, MLEMs recover ground-truth feature importances in synthetic datasets better than state-of-the-art methods such as Feature Reweighted RSA (FR-RSA). Second, we deploy MLEMs on real language data, where they show stronger robustness to noise when estimating the importance of linguistic features (gender, tense, etc.). MLEMs are applicable to any domain where theoretical features can be identified, such as language, vision, or audition. We release optimized code for measuring feature importance in the representations of any artificial neural network, or in empirical neural data, at https://github.com/LouisJalouzot/MLEM.
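
The sketch below builds the two distance spaces being matched, following the construction described for Figure 1 (panels A and B). For categorical features, the feature distance is a mismatch indicator. For neural representations, the distance is Euclidean. Because RDMs are symmetric with a zero diagonal, only pairs with $i < j$ are kept. The function names, feature names, and toy embeddings are hypothetical illustrations, not the released package's API.

```python
import math
from itertools import combinations


def feature_distance_vectors(stimuli, features):
    """Return the pairs (i, j), i < j, and their vectors D^F_ij.

    D^F_ij[k] = 1.0 if stimuli i and j differ on categorical feature k, else 0.0,
    i.e. the (i, j) entries of the per-feature RDMs stacked into one vector.
    """
    pairs = list(combinations(range(len(stimuli)), 2))
    vectors = [[float(stimuli[i][f] != stimuli[j][f]) for f in features] for i, j in pairs]
    return pairs, vectors


def neural_distances(embeddings, pairs):
    """Euclidean distance D^N_ij between the representations of each pair."""
    return [math.dist(embeddings[i], embeddings[j]) for i, j in pairs]


# Hypothetical toy data: three sentences, two features, 2D "embeddings".
stimuli = [
    {"number": "singular", "tense": "past"},
    {"number": "plural", "tense": "past"},
    {"number": "singular", "tense": "present"},
]
pairs, D_F = feature_distance_vectors(stimuli, ["number", "tense"])
D_N = neural_distances([[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]], pairs)
print(pairs)  # [(0, 1), (0, 2), (1, 2)]
print(D_F)    # [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
```

`math.dist` requires Python 3.8 or later. For real LLM embeddings or fMRI patterns, a vectorized library would be the practical choice. The pure-Python version only shows the bookkeeping.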

Paper Structure (46 sections, 16 equations, 20 figures, 2 tables)

Figures (20)

  • Figure 1: Overview of the Metric Learning Encoding Models (MLEMs) approach, illustrated with natural language. A: Consider a set of $n$ stimuli (rows; sentences) and $m$ theoretical features (columns; grammatical number, tense, gender, ...), along with their values for each stimulus. In this example, the features are categorical. The distance for feature $f_k$ between a pair of stimuli $(s_i, s_j)$ indicates whether the two stimuli differ on that feature: $\mathbb{1}\left[f_k(s_i) \neq f_k(s_j)\right]$. Accordingly, a Representational Dissimilarity Matrix (RDM) is computed for each of the $m$ features ($D^\text{Number}, D^\text{Tense}, D^\text{Gender}, ... \in \mathbb{R}_+^{n \times n}$), whose elements indicate whether pairs of stimuli differ on that feature. Finally, for each pair of stimuli, let $D^F_{ij} \in \mathbb{R}^m$ be the vector composed of the $(i, j)$ elements of all $m$ feature RDMs (bottom panel). $D^F_{ij}$ summarizes the theoretical distance between the two stimuli with respect to all features. B: Separately, consider the $n$ high-dimensional neural representations of the stimuli (e.g., LLM embeddings or fMRI images, shown here in 2D). From these representations, one can compute an RDM of neural distances $D^N$ (e.g., based on Euclidean distance). C: The core metric-learning component of MLEM learns a weighted norm of the feature-distance vectors, $\|D^F_{ij}\|_W$, to approximate the empirical neural distances $D_{ij}^N$. This approximation is optimized to maximize the Spearman correlation $\rho$ between theoretical and neural distances. The norm is parametrized by a learned symmetric positive definite (SPD) matrix $W\in\mathbb{S}^{++}_m$ so that it yields a valid metric (distance function). Crucially, the off-diagonal coefficients of $W$ directly model interactions between features (e.g., between grammatical number and gender), which may have a combined effect on neural representations. D: Permutation Feature Importance is applied to a fitted MLEM model to assess how much each feature, and each interaction, contributes to predicting neural distances. In practice, the importance is the average decrease in the Spearman correlation $\rho$ when the feature values are permuted.
  • Figure 2: MLEM is more accurate and more robust to noise than FR-RSA-I. Frobenius distance, at different noise levels, between the estimated weights and (a) the ground-truth weights in synthetic data (accuracy), and (b) the estimate at noise level 0 on LLM embeddings (robustness). MLEM (blue) consistently achieves lower error than FR-RSA-I (orange). In both plots, the solid line shows the mean and the shaded area the standard deviation across 5 runs.
  • Figure 3: MLEM converges faster than FR-RSA-I. The plots compare the number of training steps (epochs) that MLEM (blue) and FR-RSA-I (orange) need to converge. (a) On simulated datasets, MLEM converges faster at all noise levels. (b) On the relative clause dataset with LLM embeddings, MLEM converges faster for each of the 12 layers, indicating that it is more computationally efficient. The solid line shows the mean and the shaded area the standard deviation across 5 runs.
  • Figure 4: MLEM and FR-RSA-I identify similar profiles of feature importance across LLM layers. The plots show the top feature importances for MLEM (left) and FR-RSA-I (right) across the 12 layers of the LLM. Each line represents a feature or interaction, and its y-value indicates its importance at a given layer. Both methods reveal a similar pattern: features related to sentence structure, such as "Relative Clause type" and "Attachment site", gain importance in the middle layers, consistent with prior work on syntactic processing in LLMs. This suggests that both methods can uncover meaningful linguistic information. The solid line shows the mean and the shaded area the standard deviation across 5 cross-validation folds.
  • Figure 5: The geometry of LLM representations reflects syntactic structure, especially in middle layers. This figure shows 2D MDS visualizations of sentence representations from each of the LLM's 12 layers. Points are colored by the "Relative Clause type" feature (object relative: blue, subject relative: orange). A clear clustering by this feature emerges in the middle layers, where the two colors become separable. This geometric separation coincides with the peak importance of the "Relative Clause type" feature shown in Figure 4. It supports the conclusion that the feature importances identified by the approach correspond to tangible structures in neural space.
  • ...and 15 more figures
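
Panels C and D of Figure 1 describe two steps: scoring a fitted model by Spearman's $\rho$ between $\|D^F_{ij}\|_W$ and $D^N_{ij}$, and computing permutation feature importance from that score. The stdlib-only sketch below illustrates both under several assumptions:

- $W$ is taken as already fitted; the fitting procedure is not shown.
- Features are categorical with mismatch distances.
- "Permuting the feature values" is read as shuffling one feature's values across stimuli and then recomputing the pairwise distances.
- Only single-feature importances are computed, not interaction importances.

All function names and toy data are hypothetical and are not the API of the released MLEM package.

```python
import math
import random
from itertools import combinations


def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    r = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            r[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return r


def spearman(x, y):
    """Spearman's rho: Pearson correlation between tie-averaged ranks."""
    rx, ry = ranks(x), ranks(y)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = math.sqrt(sum((a - mx) ** 2 for a in rx))
    sy = math.sqrt(sum((b - my) ** 2 for b in ry))
    return cov / (sx * sy) if sx > 0 and sy > 0 else 0.0


def mlem_score(stimuli, features, W, D_N):
    """rho between ||D^F_ij||_W and D^N_ij, pairs i < j in combinations() order."""
    m = len(features)
    predicted = []
    for i, j in combinations(range(len(stimuli)), 2):
        d = [float(stimuli[i][f] != stimuli[j][f]) for f in features]
        predicted.append(math.sqrt(sum(d[a] * W[a][b] * d[b] for a in range(m) for b in range(m))))
    return spearman(predicted, D_N)


def permutation_importance(stimuli, features, W, D_N, n_repeats=20, seed=0):
    """Average drop in rho when one feature's values are shuffled across stimuli."""
    rng = random.Random(seed)
    base = mlem_score(stimuli, features, W, D_N)
    importance = {}
    for f in features:
        drops = []
        for _ in range(n_repeats):
            values = [s[f] for s in stimuli]
            rng.shuffle(values)
            shuffled = [{**s, f: v} for s, v in zip(stimuli, values)]
            drops.append(base - mlem_score(shuffled, features, W, D_N))
        importance[f] = sum(drops) / n_repeats
    return importance


# Hypothetical toy data: neural distances driven mostly by "number".
stimuli = [
    {"number": "sg", "tense": "past"}, {"number": "pl", "tense": "past"},
    {"number": "sg", "tense": "present"}, {"number": "pl", "tense": "present"},
]
D_N = [10.0, 1.0, 11.0, 11.0, 1.0, 10.0]  # pairs (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)
W = [[1.0, 0.0], [0.0, 0.01]]              # assumed already fitted
print(mlem_score(stimuli, ["number", "tense"], W, D_N))
print(permutation_importance(stimuli, ["number", "tense"], W, D_N, n_repeats=200))
```

The toy neural distances depend mostly on number, so shuffling number should lower $\rho$ much more than shuffling tense. Tie-aware ranking matters here: binary feature distances produce many tied values, and averaging ranks over ties is the standard Spearman convention.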