WiNet: Wavelet-based Incremental Learning for Efficient Medical Image Registration

Xinxing Cheng, Xi Jia, Wenqi Lu, Qiufu Li, Linlin Shen, Alexander Krull, Jinming Duan

TL;DR

WiNet tackles memory and explainability bottlenecks in 3D medical image registration with a model-driven, wavelet-based approach. It embeds a differentiable discrete wavelet transform (DWT) in the encoder and uses an Incremental Deformation Learning Module to learn scale-wise wavelet coefficients, from which a differentiable inverse DWT (IDWT) reconstructs the full-resolution deformation field $\boldsymbol{\phi}$. A diffeomorphic variant, WiNet-Diff, computes $\boldsymbol{\phi}=\exp(\boldsymbol{v})$ from a velocity field $\boldsymbol{v}$ using seven scaling-and-squaring layers. Training uses a combined loss on $\boldsymbol{\phi}$ or $\boldsymbol{v}$ that balances similarity and smoothness. Across the IXI and 3D-CMR datasets, WiNet achieves competitive Dice scores and a lower memory footprint than pyramid and model-driven baselines, and WiNet-Diff attains 0% and 0.007% topology violations on IXI and 3D-CMR, respectively, highlighting improved efficiency and explainability for large-volume registration.
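To make the scaling-and-squaring step concrete, here is a minimal 1D sketch in plain Python. It is an illustration under stated assumptions, not the authors' implementation: the real model works on 3D fields, and the helper names (`interp_linear`, `scaling_and_squaring`, `folding_fraction`), the border clamping, and the linear interpolation are all choices made here for brevity. The folding measure is a 1D stand-in for the topology-violation percentage reported above.

```python
import math


def interp_linear(values, x):
    """Linearly interpolate samples at positions 0..n-1, clamping x to the border."""
    n = len(values)
    x = min(max(x, 0.0), n - 1.0)
    i = min(int(x), n - 2)
    t = x - i
    return (1.0 - t) * values[i] + t * values[i + 1]


def scaling_and_squaring(velocity, steps=7):
    """Approximate the displacement of exp(v) for a stationary 1D velocity field."""
    # Scaling: start from a very small displacement v / 2^steps.
    disp = [v / (2.0 ** steps) for v in velocity]
    # Squaring: compose the map with itself, u(x) <- u(x) + u(x + u(x)).
    for _ in range(steps):
        disp = [d + interp_linear(disp, i + d) for i, d in enumerate(disp)]
    return disp


def folding_fraction(disp):
    """Fraction of grid intervals where the 1D Jacobian 1 + du/dx is non-positive."""
    jac = [1.0 + (disp[i + 1] - disp[i]) for i in range(len(disp) - 1)]
    return sum(1 for j in jac if j <= 0.0) / len(jac)


# Hypothetical velocity field with steep gradients (assumed values, for illustration).
velocity = [8.0 * math.sin(2.0 * math.pi * i / 32) for i in range(32)]
direct = folding_fraction(velocity)                            # v used as a displacement
integrated = folding_fraction(scaling_and_squaring(velocity))  # displacement of exp(v)
```

For this particular field, using `velocity` directly as a displacement folds the grid on some intervals, while the integrated displacement does not, which is the property a diffeomorphic variant is designed to provide.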

Abstract

Deep image registration has demonstrated exceptional accuracy and fast inference. Recent advances have adopted either multiple cascades or pyramid architectures to estimate dense deformation fields in a coarse-to-fine manner. However, due to their cascaded nature and repeated composition/warping operations on feature maps, these methods increase memory usage during both training and testing. Moreover, such approaches impose no explicit constraints on how small deformations are learned at different scales, and therefore lack explainability. In this study, we introduce a model-driven WiNet that incrementally estimates scale-wise wavelet coefficients for the displacement/velocity field across various scales, utilizing the wavelet coefficients derived from the original input image pair. By exploiting the properties of the wavelet transform, these estimated coefficients facilitate the seamless reconstruction of a full-resolution displacement/velocity field via our devised inverse discrete wavelet transform (IDWT) layer. This approach avoids the complexities of cascading networks or composition operations, making our WiNet an explainable and efficient competitor to other coarse-to-fine methods. Extensive experimental results on two 3D datasets show that our WiNet is accurate and GPU efficient. The code is available at https://github.com/x-xc/WiNet.
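The reconstruction described in the abstract rests on the invertibility of the wavelet transform: low- and high-frequency coefficients at half resolution determine the full-resolution signal exactly. The sketch below shows this for a one-level, 1D orthonormal Haar transform in plain Python. The choice of the Haar wavelet and the function names `haar_dwt` and `haar_idwt` are assumptions made for illustration; this summary does not state which wavelet WiNet uses, and the actual layers operate on 3D volumes.

```python
import math


def haar_dwt(signal):
    """One-level orthonormal Haar DWT of an even-length 1D sequence."""
    if len(signal) % 2:
        raise ValueError("signal length must be even")
    s = math.sqrt(2.0)
    pairs = [(signal[i], signal[i + 1]) for i in range(0, len(signal), 2)]
    low = [(a + b) / s for a, b in pairs]   # local averages (low frequency)
    high = [(a - b) / s for a, b in pairs]  # local differences (high frequency)
    return low, high


def haar_idwt(low, high):
    """Inverse of haar_dwt: rebuilds the signal at twice the resolution."""
    if len(low) != len(high):
        raise ValueError("low and high must have the same length")
    s = math.sqrt(2.0)
    out = []
    for a, d in zip(low, high):
        out.append((a + d) / s)
        out.append((a - d) / s)
    return out


# Round trip on a hypothetical signal: the reconstruction matches the input
# up to floating-point error.
x = [4.0, 2.0, 5.0, 5.0, -1.0, 3.5]
low, high = haar_dwt(x)
x_rec = haar_idwt(low, high)
```

Both functions are parameter-free, which matches the description of the IDWT layer as a fixed, shared operation rather than a learned one.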

Paper Structure (9 sections, 4 equations, 4 figures, 1 table)

Figures (4)

  • Figure 1: Architecture of WiNet. The encoder includes a shared DWT layer and four convolutional contracting layers. The decoder consists of three convolutional expansion layers and an Incremental Deformation Learning Module, which converts the convolutional features into the final displacement $\boldsymbol{\phi}$ using a convolutional layer (Conv-0), two refinement blocks (RB-1 and RB-2), and a shared parameter-free IDWT layer. The 3D images, features, and displacements are shown in 2D for illustration.
  • Figure 2: Illustration of the coarse-to-fine incremental deformation learning module.
  • Figure 3: Comparison of the number of parameters (the area of circles), GPU training memory on IXI, and averaged GPU inference time on the same machine. Fourier-Net, TM-B-Spline-Diff, and B-Spline-Diff are abbreviated as F-Net, TM-BS-Diff, and BS-Diff, respectively. Our WiNet requires the lowest memory (6977 MiB) for training.
  • Figure 4: Visual comparisons on IXI and 3D-CMR. Column 1: Fixed image, Moving image. Columns 2-10: warped moving images, displacement fields as RGB images.
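
The coarse-to-fine idea behind Figures 1 and 2 can be sketched as repeated inverse wavelet steps: start from a coarse approximation, then at each scale combine it with high-frequency coefficients and apply the same IDWT to double the resolution. The plain-Python example below is a 1D illustration only. It assumes an orthonormal Haar wavelet and three reconstruction levels, and it does not model which coefficients Conv-0, RB-1, and RB-2 actually predict, since the captions do not specify this. The names `haar_idwt` and `reconstruct_incrementally` are hypothetical. In 3D, each level would carry one low-frequency and seven high-frequency subbands rather than one of each.

```python
import math


def haar_idwt(low, high):
    """One-level inverse orthonormal Haar transform (parameter-free)."""
    if len(low) != len(high):
        raise ValueError("low and high must have the same length")
    s = math.sqrt(2.0)
    out = []
    for a, d in zip(low, high):
        out.append((a + d) / s)
        out.append((a - d) / s)
    return out


def reconstruct_incrementally(coarse, details):
    """Rebuild a full-resolution 1D field from a coarse approximation.

    coarse:  low-frequency coefficients at the coarsest scale.
    details: high-frequency coefficients per scale, coarsest first; each
             list must match the current resolution of the field.
    """
    field = list(coarse)
    for high in details:
        # The same IDWT is reused at every scale; only the coefficients change.
        field = haar_idwt(field, high)
    return field


# Hypothetical coefficients: 2 coarse values refined over three scales into 16 samples.
coarse = [1.0, -2.0]
details = [
    [0.5, 0.25],
    [0.1, -0.1, 0.2, 0.0],
    [0.0] * 8,  # all-zero details add no new structure at the finest scale
]
field = reconstruct_incrementally(coarse, details)
```

Because no intermediate field is warped or composed with another, the only operations between scales are coefficient prediction and the fixed inverse transform, which is the source of the memory savings the paper claims over cascaded and pyramid designs.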