
Model Alignment Search

Satchel Grant

TL;DR

Model Alignment Search (MAS) proposes a causal, bidirectional approach to neural similarity: it learns an invertible alignment for each model and performs interchange interventions on behaviorally relevant subspaces. The paper situates MAS between model stitching and Distributed Alignment Search (DAS), offering a behavior-driven measure of similarity that reduces computational requirements and reveals subtleties that correlational methods miss. It demonstrates MAS on a family of numeric tasks and on a toxicity case study with fine-tuned language models, introduces a Counterfactual Latent (CL) auxiliary objective for comparisons in which one network is causally inaccessible (CLMAS), and shows that MAS can isolate specific causal information while maintaining competitive alignment with fewer resources. The findings advocate for causal methods in neural similarity analysis and outline directions for extending network alignment methodology, including biological applications and larger-scale multi-model comparisons.

Abstract

When can we say that two neural systems perform a task in the same way? What nuances do we miss when we fail to causally probe the representations of the systems, and how do we establish bidirectional causal relationships? In this work, we introduce a method, Model Alignment Search (MAS), that bidirectionally transfers neural activity between artificial neural networks and uses the resulting behavior as a measure of functional similarity. We first show that the method can be used to transfer the behavior from one frozen Neural Network (NN) to another in a manner similar to model stitching, and we show how the method can differ from correlative similarity measures like Representational Similarity Analysis. Next, we empirically and theoretically show how the method can be equivalent to model stitching when desired, or can take a more restrictive form that focuses on shared causal information; in both forms, it reduces the number of matrices required to compare n models to a number linear in n. We then present a case study on number-related tasks showing that the method can be used to examine specific subtypes of causal information, demonstrating that numbers can be encoded differently in recurrent models depending on the task, and we present another case study showing that MAS can reveal misalignment in fine-tuned DeepSeek-r1-Qwen-1.5B models. Lastly, we augment the loss function with a counterfactual latent (CL) auxiliary objective to improve causal relevance when one of the two networks is causally inaccessible (as is often the case in comparisons with biological networks). We use our results to encourage the use of causal methods in neural similarity analyses and to suggest future explorations of network similarity methodology for model misalignment.
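To make the linear-in-n claim concrete, the small count below compares the two regimes. It assumes that pairwise model stitching trains one map per direction for every ordered pair of models, and that MAS learns one alignment matrix per model that is reused in every comparison; both function names are hypothetical.

```python
def stitching_matrix_count(n_models: int) -> int:
    """Assumed cost of pairwise stitching: one learned map per ordered pair of models."""
    return n_models * (n_models - 1)


def mas_matrix_count(n_models: int) -> int:
    """Assumed cost of MAS: one learned invertible alignment per model, shared across comparisons."""
    return n_models


for n in (2, 5, 10, 50):
    print(f"n={n}: stitching={stitching_matrix_count(n)}, MAS={mas_matrix_count(n)}")
```

Under these assumptions, comparing 50 models needs 2,450 stitching maps but only 50 MAS alignments.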

Paper Structure

This paper contains 30 sections, 17 equations, and 6 figures.

Figures (6)

  • Figure 1: (a) A depiction of an interchange intervention, from the MAS interchange-intervention equation (eq:masinterchange), on the target latent vector $h^{(trg)}_{\psi_i}$ from model $\psi_i$ using the source latent vector $h^{(src)}_{\psi_k}$ from $\psi_k$. Rectangles represent vectors; colors distinguish between behaviorally relevant and extraneous activity. The causally relevant information is spread across the neural population [Smolensky1988; park2023linearrephypoth], represented by the red and green semi-vertical slices in the respective $h$ vectors. To disentangle the relevant information and isolate the behavioral null-space, we rotate the $h$ vectors into an aligned space using $Q_{\psi_i}$ and $Q_{\psi_k}$, in which the behaviorally relevant information occupies dimensions separate from the extraneous, behavioral null-space. We can then intervene on and transfer the relevant information without affecting other information. In the figure, this is done by applying binary masks (black represents 0s, white represents 1s, and $\odot$ is the Hadamard product) to the $z$ vectors and summing the results. We then invert the rotation on $z^{(v)}_{\psi_i}$ to return it to the original latent space, where $\psi_i$ can use it to make predictions. (b) Stepwise MAS, in which the individual intervention shown in panel (a) is applied at multiple token positions in the sequence. We limit our interventions to contiguous sets of tokens starting from the first token and ending at a sampled position $t$.
  • Figure 2: (a) and (b) A comparison of MAS, measured as interchange intervention accuracy (IIA), on the left axes with CKA and RSA on the right axes. We examine both the input embeddings and the hidden state vectors for models trained on the Multi-Object task. Dashed lines indicate the values obtained when comparing individual models with themselves. (a) Results for GRUs compared to GRUs: RSA can give low estimates of embedding similarity across model seeds, whereas MAS shows the high causal transfer we might expect. (b) Results for GRUs compared to 2-layer Transformers: RSA shows an effect in the embeddings similar to (a), and CKA and RSA potentially over-estimate the similarity of the hidden states. This over-estimation is with respect to causal transfer: prior work [grant2024das] has shown that the transformers use anti-Markovian solutions, recomputing the relevant information at each step in the sequence, which is reflected in the low MAS IIA. (c) IIA comparing fine-tuned toxic and nontoxic LLMs using Stepwise MAS. Toxic models have higher IIA with themselves than with the nontoxic models. Notably, there is no significant difference for the nontoxic models compared to themselves. (d) Comparison of the IIA from DAS and MAS for different sizes of the aligned subspace, and of model stitching with transformation matrices of different rank.
  • Figure 3: (a) Comparison of the IIAs for CLMAS, Latent Stitching, and behavioral Stitching in different intervention directions on the Multi-Object GRU models. The dashed line indicates MAS IIA as an upper bound. On the x-axis, Access refers to interchange interventions from the inaccessible $\tilde{\psi}_1^{(src)}$ to the accessible $\psi_2^{(trg)}$, and No Access refers to interventions from the accessible $\psi_2^{(src)}$ to the inaccessible $\tilde{\psi}_1^{(trg)}$. The Stitch results come from model stitching trained from $\tilde{\psi}_1^{(src)}$ to $\psi_2^{(trg)}$, and the Latent Stitch results are trained in the inaccessible direction, from $\psi_2^{(trg)}$ to $\tilde{\psi}_1^{(src)}$, without behavioral training. Both Access and No Access values are reported for the training step with the best No Access IIA (which is why Stitching does not reach 100% IIA in the Access direction). CLMAS has the best performance in the No Access direction. (b) A comparison of the transferability of the behaviorally relevant numeric information between the Multi-Object GRU models and the Multi-Object, Rounding, Modulo, and Same-Object models. DAS provides an upper bound on MAS performance, which would be reached if $\psi_1$ and $\psi_2$ represented numbers in the same way. (c) Example token sequences of the GRU tasks from panel (b).
  • Figure 4: Figure and caption taken from [grant2024das]. Diagram of the transformer architecture used in this work. White rectangles represent activation vectors; arrows represent functional operations. All causal interventions were performed on either the Hidden State activations from Layer 1 or the Embeddings layer. All normalizations are Layer Norms [ba2016layernorm].
  • Figure 5: (a) Diagram of MAS between models trained on structurally different tasks. We see all four intervention directions on the latent vectors (rectangles) of a Multi-Object GRU, $\psi_1$, and an Arithmetic GRU, $\psi_2$. The values of the causal variables following each input token are shown above the $h_{src}$ vectors and below the $h_{trg}$ vectors. In the $h_{trg}$ vectors, each variable's value before and after the intervention appears to the left and right of the arrow, respectively: the Count for $\psi_1$ and Rem Ops for $\psi_2$. Each dotted Substitution arrow corresponds to a single intervention. Following the intervention, the models make predictions using the intervened vector. (b) A 2D vector depiction of a hypothetical intervention that substitutes the value of the Count variable from $\psi_1$, $\Vec{z}^{\,(1)}_{count}$, into the Count variable $\Vec{z}^{\,(2)}_{count}$ in $\psi_2$, where the superscripts refer to the originating model. Using learned matrices $Q_1$ and $Q_2$, $h_{src}^{(1)}$ and $h_{trg}^{(2)}$ are rotated into $z_{src}^{(1)}$ and $z_{trg}^{(2)}$, where the $\Vec{z}_{count}$ subspace is organized into a contiguous subset of vector dimensions disentangled from all other information. In this aligned space the $\Vec{z}_{count}$ values can be freely exchanged without affecting other information. Lastly, $z_{\text{v}}^{(2)}$ is returned to $\psi_2$'s hidden state space by inverting $Q_2$ and is used for inference by $\psi_2$.
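The interchange intervention described in Figure 1(a) and Figure 5(b) can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the paper's implementation: both models share a latent size d, the alignment matrices are orthogonal (so inverting the rotation is a transpose), the convention is z = Q h, and Q is random rather than learned. The helper names `random_orthogonal` and `mas_interchange` are hypothetical.

```python
import numpy as np


def random_orthogonal(d: int, rng: np.random.Generator) -> np.ndarray:
    """Random d x d rotation via QR; stands in for a learned alignment matrix."""
    q, r = np.linalg.qr(rng.standard_normal((d, d)))
    return q * np.sign(np.diag(r))  # flip column signs; the result stays orthogonal


def mas_interchange(h_trg, h_src, Q_trg, Q_src, mask):
    """Copy the masked (behaviorally relevant) aligned dims from source into target.

    Assumes z = Q @ h, orthogonal Q, and a binary mask marking the relevant dims.
    """
    z_trg = Q_trg @ h_trg                    # target latent in the aligned space
    z_src = Q_src @ h_src                    # source latent in the aligned space
    z_v = (1 - mask) * z_trg + mask * z_src  # sum of the two Hadamard-masked vectors
    return Q_trg.T @ z_v                     # invert the target rotation back to h-space


rng = np.random.default_rng(0)
d = 4
Q_1, Q_2 = random_orthogonal(d, rng), random_orthogonal(d, rng)
h_src = rng.standard_normal(d)               # source latent from model 1
h_trg = rng.standard_normal(d)               # target latent from model 2
mask = np.array([1.0, 1.0, 0.0, 0.0])        # first two aligned dims treated as relevant

h_v = mas_interchange(h_trg, h_src, Q_trg=Q_2, Q_src=Q_1, mask=mask)
z_v = Q_2 @ h_v
print(np.allclose(z_v[:2], (Q_1 @ h_src)[:2]))  # True: relevant dims come from the source
print(np.allclose(z_v[2:], (Q_2 @ h_trg)[2:]))  # True: null-space dims are untouched
```

In Stepwise MAS (Figure 1b) the same operation would be applied at each token position from the first token up to a sampled position t, with the network recomputing later states from the intervened ones; that recurrence is omitted here. Because the alignments are described only as invertible, a general invertible matrix with an explicit inverse could replace the transpose.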
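The IIA values reported in Figures 2 and 3 are accuracies over intervened trials. A minimal sketch, assuming each trial yields one intervened prediction and one counterfactual label (the output expected if the transferred information took effect); the function name and the toy arrays are hypothetical, and the paper may score predictions per token over a sequence.

```python
import numpy as np


def interchange_intervention_accuracy(intervened_preds, counterfactual_labels) -> float:
    """Fraction of interventions whose output matches the counterfactual label."""
    preds = np.asarray(intervened_preds)
    labels = np.asarray(counterfactual_labels)
    if preds.shape != labels.shape:
        raise ValueError("predictions and labels must have the same shape")
    return float(np.mean(preds == labels))


# Toy trials: after copying a source model's Count into a target model, the target is
# expected to continue from the source's count (these values are made up).
preds = [3, 5, 2, 7]
labels = [3, 5, 1, 7]
print(interchange_intervention_accuracy(preds, labels))  # 0.75
```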
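Figure 2 contrasts MAS with correlational measures. For reference, linear CKA can be computed with the standard centered formula $\lVert Y^\top X\rVert_F^2 / (\lVert X^\top X\rVert_F \, \lVert Y^\top Y\rVert_F)$, sketched below; the function name and the random data are illustrative only. The demo shows that the score is unchanged when one representation is rotated: the measure compares representational geometry without running either network, so on its own it cannot say whether the shared structure is what the models use for behavior.

```python
import numpy as np


def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Linear CKA between activations X (n x d1) and Y (n x d2) for the same n inputs."""
    X = X - X.mean(axis=0, keepdims=True)  # center each feature
    Y = Y - Y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(Y.T @ X, "fro") ** 2
    return float(cross / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")))


rng = np.random.default_rng(0)
X = rng.standard_normal((100, 8))                 # hypothetical activations of one model
Q, _ = np.linalg.qr(rng.standard_normal((8, 8)))  # random rotation
print(round(linear_cka(X, X @ Q), 6))             # 1.0: invariant to rotation
print(linear_cka(X, rng.standard_normal((100, 8))))  # much lower for unrelated activations
```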