How not to Stitch Representations to Measure Similarity: Task Loss Matching versus Direct Matching
András Balogh, Márk Jelasity
TL;DR
This work analyzes how to quantify the similarity of neural representations and argues that a purely task-driven, functional notion of similarity can be misleading because it can rely on out-of-distribution (OOD) representations. Within a model-stitching framework, it compares task loss matching (TLM), direct matching (DM), and structural similarity indices (CKA, PWCCA, and orthogonal Procrustes distance, OPD). It demonstrates that TLM often yields OOD representations that fail basic sanity checks, whereas DM preserves in-distribution structure while remaining consistent with functional compatibility. The authors provide extensive empirical evidence across ResNet and ViT architectures on CIFAR-10, SVHN, and ImageNet, showing that DM offers a more reliable, hybrid (structural and functional) notion of similarity and calling into question conclusions drawn from OPD and related indices. They further show that DM remains robust under statistical probing tests, whereas TLM's conclusions are undermined by OOD effects. This suggests that, in practice, hybrid similarity indices that are consistent with both structure and function should be preferred when comparing representations across models and tasks.
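For reference, two of the structural indices named above can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions: activation matrices have one row per input example, CKA is the linear-kernel variant, and the orthogonal Procrustes distance is computed after mean-centering and scaling to unit Frobenius norm. The function names are hypothetical, PWCCA is omitted, and the exact variants and preprocessing used in the paper may differ.

```python
import numpy as np


def center(X: np.ndarray) -> np.ndarray:
    """Subtract each feature's mean; rows are examples, columns are features."""
    return X - X.mean(axis=0, keepdims=True)


def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Linear CKA between X (n, p1) and Y (n, p2) computed on the same n inputs."""
    X, Y = center(X), center(Y)
    cross = np.linalg.norm(Y.T @ X, ord="fro") ** 2
    norm_x = np.linalg.norm(X.T @ X, ord="fro")
    norm_y = np.linalg.norm(Y.T @ Y, ord="fro")
    return float(cross / (norm_x * norm_y))


def procrustes_distance(X: np.ndarray, Y: np.ndarray) -> float:
    """Orthogonal Procrustes distance: min over orthogonal R of ||X R - Y||_F^2.

    After centering and scaling both matrices to unit Frobenius norm this equals
    2 - 2 * ||X^T Y||_* (nuclear norm); differing widths behave as if the
    narrower matrix were zero-padded.
    """
    X, Y = center(X), center(Y)
    X = X / np.linalg.norm(X, ord="fro")
    Y = Y / np.linalg.norm(Y, ord="fro")
    return float(2.0 - 2.0 * np.linalg.norm(X.T @ Y, ord="nuc"))


# Usage on synthetic activations (hypothetical data).
rng = np.random.default_rng(0)
A = rng.normal(size=(500, 64))
Q, _ = np.linalg.qr(rng.normal(size=(64, 64)))   # random orthogonal matrix
print(linear_cka(A, 3.0 * A @ Q))                 # close to 1.0: rotation/scale invariant
print(procrustes_distance(A, A @ Q))              # close to 0.0
print(linear_cka(A, rng.normal(size=(500, 32))))  # much lower for unrelated features
```

Both indices compare only the geometry of the two activation sets; neither looks at whether one representation can actually be consumed by the other network, which is the gap that stitching-based indices aim to fill.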
Abstract
Measuring the similarity of the internal representations of deep neural networks is an important and challenging problem. Model stitching has been proposed as a possible approach, where two half-networks are connected by mapping the output of the first half-network to the input of the second one. The representations are considered functionally similar if the resulting stitched network achieves good task-specific performance. The mapping is normally created by training an affine stitching layer on the task at hand while freezing the two half-networks, a method called task loss matching. Here, we argue that task loss matching may be very misleading as a similarity index. For example, it can indicate very high similarity between very distant layers, whose representations are known to have different functional properties. Moreover, it can indicate very distant layers to be more similar than architecturally corresponding layers. Even more surprisingly, when comparing layers within the same network, task loss matching often indicates that a given layer is more similar to some other layer than to itself. We argue that the main reason behind these problems is that task loss matching tends to create out-of-distribution representations to improve task-specific performance. We demonstrate that direct matching (when the mapping minimizes the distance between the stitched representations) does not suffer from these problems. We compare task loss matching, direct matching, and well-known similarity indices such as CCA and CKA. We conclude that direct matching strikes a good balance between the structural and functional requirements for a good similarity index.
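The difference between the two ways of fitting the affine stitching layer can be illustrated with a small NumPy sketch. Here direct matching is written as a closed-form least-squares fit of `A @ W + b` to the activations `B` that the second half-network normally receives, and task loss matching trains the same affine map by gradient descent on cross-entropy through a frozen second half. This is a toy sketch under loudly stated assumptions, not the authors' implementation: the "second half-network" is a single frozen linear softmax layer `V`, the data are synthetic, the stitching layer is zero-initialized, and the function names and hyperparameters (`lr`, `steps`) are hypothetical.

```python
import numpy as np


def direct_matching(A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Affine map (W, b) minimizing ||A @ W + b - B||_F^2 (closed-form least squares).

    A: (n, p1) activations at the cut point of the first half-network.
    B: (n, p2) activations the second half-network normally receives.
    """
    A1 = np.hstack([A, np.ones((A.shape[0], 1))])  # column of ones for the bias
    sol, *_ = np.linalg.lstsq(A1, B, rcond=None)
    return sol[:-1], sol[-1]


def softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def task_loss_matching(A, labels, V, c, lr=0.5, steps=1000):
    """Affine map (W, b) trained by gradient descent on the task loss.

    Toy assumption: the frozen second half-network is one linear softmax layer
    with weights V (p2, k) and bias c (k,); only W and b are updated.
    """
    n, p1 = A.shape
    p2, k = V.shape
    W, b = np.zeros((p1, p2)), np.zeros(p2)  # zero initialization (an assumption)
    onehot = np.eye(k)[labels]
    for _ in range(steps):
        Z = A @ W + b                   # stitched representation
        P = softmax(Z @ V + c)          # frozen second half
        dZ = ((P - onehot) / n) @ V.T   # gradient of mean cross-entropy w.r.t. Z
        W -= lr * (A.T @ dZ)
        b -= lr * dZ.sum(axis=0)
    return W, b


def accuracy(Z, V, c, labels):
    return float(np.mean(np.argmax(Z @ V + c, axis=1) == labels))


def relative_distance(Z, B):
    return float(np.linalg.norm(Z - B) / np.linalg.norm(B))


# Synthetic stand-ins for real activations (hypothetical data, not the paper's setup).
rng = np.random.default_rng(0)
n, p1, p2, k = 1000, 20, 32, 3
A = rng.normal(size=(n, p1))                     # output of the "first half"
B = A @ rng.normal(size=(p1, p2)) / np.sqrt(p1)  # target-layer representation
V = rng.normal(size=(p2, k)) / np.sqrt(p2)       # frozen classifier head
c = np.zeros(k)
labels = np.argmax(B @ V + c, axis=1)            # labels the head gets right on B

W_dm, b_dm = direct_matching(A, B)
W_tl, b_tl = task_loss_matching(A, labels, V, c)
Z_dm, Z_tl = A @ W_dm + b_dm, A @ W_tl + b_tl

dm_acc, dm_dist = accuracy(Z_dm, V, c, labels), relative_distance(Z_dm, B)
tl_acc, tl_dist = accuracy(Z_tl, V, c, labels), relative_distance(Z_tl, B)
print(f"direct matching:    accuracy={dm_acc:.3f}, relative distance={dm_dist:.3f}")
print(f"task loss matching: accuracy={tl_acc:.3f}, relative distance={tl_dist:.3f}")
```

With zero initialization, every update to the rows of `W` and to `b` is a combination of the columns of `V`, so the task-loss-trained representation stays inside the k-dimensional subspace that the frozen head reads. It can therefore classify well while ending up far from `B`, whereas the least-squares map reproduces `B` itself. This is only a toy linear illustration of why the task loss alone does not tie the stitched representation to the target layer's activations; it is not a reproduction of the paper's out-of-distribution analysis.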
