Spatial regularisation for improved accuracy and interpretability in keypoint-based registration
Benjamin Billot, Ramya Muthukrishnan, Esra Abaci-Turk, P. Ellen Grant, Nicholas Ayache, Hervé Delingette, Polina Golland
TL;DR
This work tackles the interpretability bottleneck in unsupervised keypoint-based medical image registration by introducing a three-fold spatial regularisation loss. Its components $L_{KL}$, $L_{var}$, and $L_{rep}$ respectively shape feature maps into interpretable point spread functions, sharpen landmark precision, and enforce landmark diversity, within a unified training objective $L_{training} = L_{sim} + \lambda_{KL} L_{KL} + \lambda_{var} L_{var} + \lambda_{rep} L_{rep}$. When applied to EquiTrack and KeyMorph, the regularisation yields accurate, anatomically meaningful keypoints. It significantly improves rigid motion tracking in foetal MRI and affine registration of brain MRI, approaching the performance of supervised methods while retaining the advantages of unsupervised learning. The approach enhances the interpretability and reliability of landmark-based registration in clinical imaging. Code is available for reproducibility, and potential extensions include cross-modality scenarios and a dynamic number of keypoints.
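The training objective above can be illustrated with a minimal NumPy sketch. The exact forms of $L_{KL}$, $L_{var}$ and $L_{rep}$ are not given in this summary, so everything below is an assumption. Feature maps are turned into spatial distributions with a softmax over voxels. $L_{KL}$ is taken as the KL divergence to an isotropic Gaussian of hypothetical width `sigma` centred on each keypoint. $L_{var}$ is the mean spatial variance. $L_{rep}$ is a Gaussian-kernel penalty, with hypothetical width `tau`, on pairs of nearby keypoints. All function names are illustrative and not taken from the authors' repository.

```python
import numpy as np


def to_distributions(features):
    """Softmax over voxels: (K, *spatial) -> (K, N), each row sums to 1 (assumed normalisation)."""
    flat = features.reshape(features.shape[0], -1)
    flat = flat - flat.max(axis=1, keepdims=True)  # subtract the max for numerical stability
    p = np.exp(flat)
    return p / p.sum(axis=1, keepdims=True)


def voxel_grid(spatial_shape):
    """Voxel coordinates of shape (N, D), ordered like a C-order ravel of the spatial axes."""
    mesh = np.meshgrid(*[np.arange(s, dtype=float) for s in spatial_shape], indexing="ij")
    return np.stack([m.ravel() for m in mesh], axis=1)


def keypoint_moments(p, grid):
    """Centres of mass mu (K, D) and squared voxel-to-centre distances (K, N)."""
    mu = p @ grid
    sq_dist = ((grid[None, :, :] - mu[:, None, :]) ** 2).sum(axis=2)
    return mu, sq_dist


def kl_loss(p, sq_dist, sigma=1.0, eps=1e-12):
    """Mean KL(p_k || g_k), g_k a discrete isotropic Gaussian centred on mu_k (assumed target)."""
    logits = -sq_dist / (2.0 * sigma ** 2)
    m = logits.max(axis=1, keepdims=True)
    log_g = logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))  # log-softmax
    return float(np.mean(np.sum(p * (np.log(p + eps) - log_g), axis=1)))


def variance_loss(p, sq_dist):
    """Mean spatial variance (expected squared distance to the centre of mass); lower is sharper."""
    return float(np.mean(np.sum(p * sq_dist, axis=1)))


def repulsion_loss(mu, tau=2.0):
    """Mean Gaussian-kernel penalty over keypoint pairs; large when keypoints collapse together."""
    n_kp = mu.shape[0]
    if n_kp < 2:
        return 0.0
    d2 = ((mu[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
    i, j = np.triu_indices(n_kp, k=1)  # each unordered pair once
    return float(np.mean(np.exp(-d2[i, j] / (2.0 * tau ** 2))))


def training_loss(l_sim, features, lam_kl=1.0, lam_var=1.0, lam_rep=1.0):
    """L_training = L_sim + lam_kl * L_KL + lam_var * L_var + lam_rep * L_rep."""
    p = to_distributions(features)
    mu, sq_dist = keypoint_moments(p, voxel_grid(features.shape[1:]))
    return (l_sim + lam_kl * kl_loss(p, sq_dist)
            + lam_var * variance_loss(p, sq_dist)
            + lam_rep * repulsion_loss(mu))


# Usage: two keypoint maps on a 16x16 grid, peaked at (3, 3) and (12, 12);
# l_sim = 0.5 stands in for an image-similarity term computed elsewhere.
ys, xs = np.meshgrid(np.arange(16.0), np.arange(16.0), indexing="ij")
features = np.stack([-((ys - cy) ** 2 + (xs - cx) ** 2) / 2.0 for cy, cx in [(3, 3), (12, 12)]])
print(training_loss(0.5, features))
```

In actual training these terms would be written in a differentiable framework so that gradients flow back into the network; NumPy is used here only to keep the sketch self-contained.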
Abstract
Unsupervised registration strategies bypass the need for ground-truth transforms or segmentations by optimising similarity metrics between fixed and moved volumes. Among these methods, a recent subclass of approaches based on unsupervised keypoint detection stands out as very promising for interpretability. Specifically, these methods train a network to predict feature maps for fixed and moving images, from which explainable centres of mass are computed to obtain point clouds, which are then aligned in closed form. However, the features returned by the network often exhibit spatially diffuse patterns that are hard to interpret, thus undermining the purpose of keypoint-based registration. Here, we propose a three-fold loss to regularise the spatial distribution of the features. First, we use the KL divergence to model features as point spread functions that we interpret as probabilistic keypoints. Then, we sharpen the spatial distributions of these features to increase the precision of the detected landmarks. Finally, we introduce a new repulsive loss across keypoints to encourage spatial diversity. Overall, our loss considerably improves the interpretability of the features, which now correspond to precise and anatomically meaningful landmarks. We demonstrate our three-fold loss in foetal rigid motion tracking and brain MRI affine registration tasks, where it not only outperforms state-of-the-art unsupervised strategies, but also bridges the gap with state-of-the-art supervised methods. Our code is available at https://github.com/BenBillot/spatial_regularisation.
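The keypoint pipeline described in the abstract (feature maps, then centres of mass, then closed-form alignment) can be sketched as follows. This is an illustrative NumPy sketch, not the authors' code. The function names are hypothetical and feature maps are assumed non-negative after clipping. The closed-form solvers shown are standard choices rather than ones confirmed by the source: the Kabsch algorithm for the rigid case and ordinary least squares for the affine case.

```python
import numpy as np


def centres_of_mass(features):
    """(K, *spatial) feature maps -> (K, D) centres of mass in voxel coordinates.

    Negative values are clipped to zero; each map is assumed to have some positive mass.
    """
    n_kp, spatial = features.shape[0], features.shape[1:]
    w = np.clip(features, 0.0, None).reshape(n_kp, -1)
    w = w / w.sum(axis=1, keepdims=True)
    mesh = np.meshgrid(*[np.arange(s, dtype=float) for s in spatial], indexing="ij")
    grid = np.stack([m.ravel() for m in mesh], axis=1)  # (N, D)
    return w @ grid


def rigid_closed_form(moving_pts, fixed_pts):
    """Kabsch: rotation R and translation t minimising sum ||R @ m_i + t - f_i||^2."""
    cm, cf = moving_pts.mean(axis=0), fixed_pts.mean(axis=0)
    H = (moving_pts - cm).T @ (fixed_pts - cf)  # (D, D) cross-covariance
    U, _, Vt = np.linalg.svd(H)
    d = 1.0 if np.linalg.det(Vt.T @ U.T) >= 0 else -1.0  # avoid returning a reflection
    S = np.diag([1.0] * (H.shape[0] - 1) + [d])
    R = Vt.T @ S @ U.T
    return R, cf - R @ cm


def affine_closed_form(moving_pts, fixed_pts):
    """Least squares: matrix A and translation t minimising sum ||A @ m_i + t - f_i||^2."""
    n_kp, dim = moving_pts.shape
    X = np.hstack([moving_pts, np.ones((n_kp, 1))])  # homogeneous coordinates (K, D+1)
    sol, *_ = np.linalg.lstsq(X, fixed_pts, rcond=None)  # (D+1, D)
    return sol[:dim].T, sol[dim]


# Usage: recover a known 3D rigid transform from six noise-free keypoints
rng = np.random.default_rng(0)
moving = rng.uniform(0.0, 32.0, size=(6, 3))
c, s = np.cos(0.3), np.sin(0.3)
R_true = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t_true = np.array([2.0, -1.0, 0.5])
fixed = moving @ R_true.T + t_true
R, t = rigid_closed_form(moving, fixed)
```

Because each step is a weighted average followed by an SVD or a linear solve, the resulting transform can be traced back to individual keypoints, which is the interpretability argument the abstract makes. The rigid solver needs at least D non-collinear points and the affine solver at least D+1 points in general position.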
