WiNet: Wavelet-based Incremental Learning for Efficient Medical Image Registration
Xinxing Cheng, Xi Jia, Wenqi Lu, Qiufu Li, Linlin Shen, Alexander Krull, Jinming Duan
TL;DR
WiNet tackles the memory and explainability bottlenecks of 3D medical image registration with a model-driven, wavelet-based approach. It embeds a differentiable discrete wavelet transform (DWT) in the encoder and uses an Incremental Deformation Learning Module to learn scale-wise wavelet coefficients, from which a differentiable inverse DWT (IDWT) reconstructs the full-resolution deformation field $\boldsymbol{\phi}$. A diffeomorphic variant, WiNet-Diff, sets $\boldsymbol{\phi} = \exp(\boldsymbol{v})$, integrating the velocity field $\boldsymbol{v}$ with seven scaling-and-squaring layers; training uses a combined loss that balances image similarity against smoothness of $\boldsymbol{\phi}$ or $\boldsymbol{v}$. On the IXI and 3D-CMR datasets, WiNet achieves competitive Dice scores with a lower memory footprint than pyramid and model-driven baselines, and WiNet-Diff yields 0% and 0.007% topology violations on IXI and 3D-CMR, respectively, highlighting improved efficiency and explainability for large-volume registration.
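The scaling-and-squaring integration of $\boldsymbol{\phi} = \exp(\boldsymbol{v})$ can be illustrated with a short NumPy sketch. It assumes a 2D stationary velocity field of shape (2, H, W) in pixel units (channel 0 = rows, channel 1 = columns), bilinear interpolation with border clamping, and hypothetical function names; WiNet itself works on 3D volumes inside a deep-learning framework, so this is not the authors' code. The default `steps=7` mirrors the seven scaling-and-squaring layers mentioned above.

```python
import numpy as np


def bilinear_sample(field, coords):
    """Sample every channel of `field` (C, H, W) at float positions `coords`
    (2, H, W; row then column) with bilinear interpolation. Positions outside
    the image are clamped to the border."""
    _, h, w = field.shape
    y = np.clip(coords[0], 0, h - 1)
    x = np.clip(coords[1], 0, w - 1)
    y0 = np.floor(y).astype(int)
    x0 = np.floor(x).astype(int)
    y1 = np.minimum(y0 + 1, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0  # fractional offsets, shape (H, W)
    top = field[:, y0, x0] * (1 - wx) + field[:, y0, x1] * wx
    bottom = field[:, y1, x0] * (1 - wx) + field[:, y1, x1] * wx
    return top * (1 - wy) + bottom * wy


def scaling_and_squaring(v, steps=7):
    """Return the displacement u such that Id + u approximates exp(v)
    for a stationary velocity field v of shape (2, H, W) in pixel units."""
    _, h, w = v.shape
    grid = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing="ij")).astype(float)
    u = v / 2.0**steps  # small-step approximation: exp(v / 2^N) ≈ Id + v / 2^N
    for _ in range(steps):
        # Squaring: phi <- phi ∘ phi, i.e. u(x) <- u(x) + u(x + u(x))
        u = u + bilinear_sample(u, grid + u)
    return u


# Usage: a constant velocity integrates to the same constant translation.
v = np.zeros((2, 16, 16))
v[1] = 2.5  # shift 2.5 pixels along the column axis
u = scaling_and_squaring(v)
print(np.allclose(u, v))  # True
```

Each squaring step doubles the integration time, so seven steps turn the scaled-down field $\boldsymbol{v}/2^7$ into an approximation of $\exp(\boldsymbol{v})$ using only seven warps.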
Abstract
Deep image registration has demonstrated exceptional accuracy and fast inference. Recent advances have adopted either multiple cascades or pyramid architectures to estimate dense deformation fields in a coarse-to-fine manner. However, because of their cascaded nature and repeated composition/warping operations on feature maps, these methods increase memory usage during both training and testing. Moreover, such approaches impose no explicit constraints on how small deformations are learned at different scales, and therefore lack explainability. In this study, we introduce a model-driven WiNet that incrementally estimates scale-wise wavelet coefficients for the displacement/velocity field across various scales, utilizing the wavelet coefficients derived from the original input image pair. By exploiting the properties of the wavelet transform, these estimated coefficients allow a full-resolution displacement/velocity field to be reconstructed seamlessly by our inverse discrete wavelet transform (IDWT) layer. This approach avoids the complexity of cascading networks and composition operations, making WiNet an explainable and efficient alternative to other coarse-to-fine methods. Extensive experiments on two 3D datasets show that WiNet is accurate and GPU-efficient. The code is available at https://github.com/x-xc/WiNet.
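The property the abstract relies on, that wavelet coefficients at every scale jointly determine the full-resolution field exactly, so that coarse-to-fine estimation needs no warping or composition between scales, can be illustrated with a recursive Haar transform. This is a minimal sketch under stated assumptions: the Haar wavelet, a single 2D field component, and all function names are chosen for illustration (a separable 3D DWT would produce one approximation and seven detail bands per level), and the network that predicts the coefficients is omitted.

```python
import numpy as np


def haar_dwt2(x):
    """One-level orthonormal 2D Haar DWT of an array with even height and width.
    Returns the half-resolution approximation band and three detail bands."""
    a, b = x[0::2, 0::2], x[0::2, 1::2]  # top-left, top-right of each 2x2 block
    c, d = x[1::2, 0::2], x[1::2, 1::2]  # bottom-left, bottom-right
    approx = (a + b + c + d) / 2.0
    details = ((a - b + c - d) / 2.0,   # differences across columns
               (a + b - c - d) / 2.0,   # differences across rows
               (a - b - c + d) / 2.0)   # diagonal differences
    return approx, details


def haar_idwt2(approx, details):
    """Exact inverse of haar_dwt2: rebuild the double-resolution array."""
    h1, h2, h3 = details
    out = np.empty((2 * approx.shape[0], 2 * approx.shape[1]))
    out[0::2, 0::2] = (approx + h1 + h2 + h3) / 2.0
    out[0::2, 1::2] = (approx - h1 + h2 - h3) / 2.0
    out[1::2, 0::2] = (approx + h1 - h2 - h3) / 2.0
    out[1::2, 1::2] = (approx - h1 - h2 + h3) / 2.0
    return out


def wavelet_decompose(x, levels):
    """Recursive DWT: coarsest approximation plus detail bands, finest first."""
    details = []
    for _ in range(levels):
        x, d = haar_dwt2(x)
        details.append(d)
    return x, details


def wavelet_reconstruct(approx, details):
    """Coarse-to-fine reconstruction: add one scale of details per IDWT step."""
    for d in reversed(details):
        approx = haar_idwt2(approx, d)
    return approx


# Usage on one (hypothetical) displacement component of size 32x32.
rng = np.random.default_rng(0)
u = rng.standard_normal((32, 32))
coarse, details = wavelet_decompose(u, levels=3)
print(coarse.shape)  # (4, 4)
print(np.allclose(wavelet_reconstruct(coarse, details), u))  # True
```

Zeroing the finest detail bands before reconstruction replaces every 2×2 block of the field by its mean, so each scale's coefficients add a well-defined level of detail on top of the coarser ones.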
