ReMiDi: Reconstruction of Microstructure Using a Differentiable Diffusion MRI Simulator

Prathamesh Pradeep Khole, Zahra Kais Petiwala, Shri Prathaa Magesh, Ehsan Mirafzali, Utkarsh Gupta, Jing-Rebecca Li, Andrada Ianus, Razvan Marinescu

TL;DR

The paper tackles the reconstruction of neuronal microstructure from diffusion MRI signals by introducing ReMiDi, a differentiable dMRI simulator implemented in PyTorch that backpropagates through a forward model based on the Bloch-Torrey PDE. It combines a fast matrix-formalism solution on a finite-element (FEM) discretization with a latent-space representation learned by a Spectral Auto-Encoder, which regularizes the inverse problem and enables reconstruction of arbitrary 3D meshes representing axonal geometry. Key contributions include a differentiable reimplementation of SpinDoctor-style physics, an efficient eigendecomposition-based forward solver, and end-to-end gradient-based mesh reconstruction demonstrated on bent, beaded, and fanned axon geometries, with comparisons to a neural network baseline. The work paves the way for in vivo mesoscopic brain mapping through geometry-aware diffusion modeling, although challenges remain in GPU memory usage, multi-compartment modeling, and application to real-world MRI data.
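
The eigendecomposition idea behind the fast forward solve can be illustrated on a toy problem. The sketch below is not ReMiDi's code and makes several assumptions: a 1D segment of length 10 μm instead of a 3D axon mesh, linear finite elements with a lumped (diagonal) mass matrix, a single compartment with no relaxation or permeability, and an idealized PGSE sequence whose gradient strength is folded into $q = \gamma g$. The helper names (`fem_1d`, `laplace_eigenbasis`, `expm`, `pgse_signal`) are hypothetical. Projecting the discretized Bloch-Torrey equation onto the lowest Laplace eigenmodes gives $\dot{c} = -(D\Lambda + i q f(t) W)\,c$, where $\Lambda$ holds the eigenvalues, $W$ is the first-moment matrix of the eigenfunctions, and $f(t) = \pm 1$ during the two pulses, so each constant-gradient interval costs one small matrix exponential. Units are μm and ms (D in μm²/ms, q in rad/(μm·ms)).

```python
import numpy as np


def fem_1d(length, n_nodes):
    """Linear finite elements on [0, length]: node positions x, stiffness K, lumped mass m."""
    x = np.linspace(0.0, length, n_nodes)
    h = x[1] - x[0]
    K = np.zeros((n_nodes, n_nodes))
    m = np.zeros(n_nodes)
    local_K = np.array([[1.0, -1.0], [-1.0, 1.0]]) / h
    for e in range(n_nodes - 1):
        idx = [e, e + 1]
        K[np.ix_(idx, idx)] += local_K
        m[idx] += h / 2.0  # lumped (diagonal) mass: half an element per node
    return x, K, m


def laplace_eigenbasis(K, m, n_eig):
    """Solve K phi = lam M phi (M = diag(m)); keep the n_eig smallest modes, M-orthonormal."""
    s = 1.0 / np.sqrt(m)
    lam, psi = np.linalg.eigh(s[:, None] * K * s[None, :])  # equivalent symmetric problem
    phi = s[:, None] * psi[:, :n_eig]
    return lam[:n_eig], phi


def expm(A, terms=20):
    """Matrix exponential by scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(A, 1)
    s = int(np.ceil(np.log2(norm))) + 1 if norm > 1.0 else 0
    A = A / 2.0 ** s
    out = np.eye(A.shape[0], dtype=A.dtype)
    term = out.copy()
    for k in range(1, terms):
        term = term @ A / k
        out = out + term
    for _ in range(s):
        out = out @ out
    return out


def pgse_signal(q, delta, Delta, D, lam, phi, m, x):
    """Normalized PGSE signal S/S0 in the truncated Laplace eigenbasis."""
    L = np.diag(D * lam)                         # diffusion operator, diagonal in this basis
    W = phi.T @ ((m * x)[:, None] * phi)         # first-moment matrix of the eigenfunctions
    c0 = phi.T @ m                               # coefficients of the uniform initial magnetization
    c = expm(-(L + 1j * q * W) * delta) @ c0     # first gradient pulse, f(t) = +1
    c = np.exp(-D * lam * (Delta - delta)) * c   # free diffusion between the pulses
    c = expm(-(L - 1j * q * W) * delta) @ c      # second gradient pulse, f(t) = -1
    return (c0 @ c) / (c0 @ c0)                  # integrated magnetization, normalized by S0


x, K, m = fem_1d(10.0, 101)                      # 10 um segment, 101 nodes
lam, phi = laplace_eigenbasis(K, m, n_eig=10)
for q in (0.0, 0.005, 0.01, 0.02):               # q = gamma * g, in rad/(um*ms)
    E = pgse_signal(q, delta=10.0, Delta=30.0, D=2.0, lam=lam, phi=phi, m=m, x=x)
    print(f"q = {q:.3f}  S/S0 = {E.real:.4f}")
```

Because $\Lambda$ is diagonal, the interval between pulses costs only an element-wise exponential, and the eigendecomposition depends on the geometry alone, so it can be reused across gradient strengths (and, in 3D, directions). In a differentiable simulator, gradients with respect to vertex positions would additionally have to flow through the eigendecomposition itself. In practice `scipy.linalg.expm` or `torch.linalg.matrix_exp` would replace the hand-rolled `expm`, which is written out here only to keep the sketch NumPy-only.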

Abstract

We propose ReMiDi, a novel method for inferring neuronal microstructure as arbitrary 3D meshes using a differentiable diffusion Magnetic Resonance Imaging (dMRI) simulator. We first implement in PyTorch a differentiable dMRI simulator that models the forward diffusion process with a finite-element method on an input 3D microstructure mesh. To make simulations significantly faster, we solve the governing differential equation semi-analytically using a matrix formalism approach. Given a reference dMRI signal $S_{\mathrm{ref}}$, we use the differentiable simulator to iteratively update the input mesh, via gradient-based learning, so that its simulated signal matches $S_{\mathrm{ref}}$. Because directly optimizing the 3D vertex coordinates is challenging, particularly owing to the ill-posedness of the inverse problem, we instead optimize a lower-dimensional latent representation of the mesh. The mesh is first encoded into spectral coefficients, which an auto-encoder further compresses into a latent $\textbf{z}$; decoding reverses both steps to recover the mesh. We present an end-to-end differentiable pipeline whose simulated signals can be tuned to match a reference signal by iteratively updating the latent representation $\textbf{z}$. We demonstrate the ability to reconstruct microstructures of arbitrary shape represented by finite-element meshes, with a focus on axonal geometries found in brain white matter, including bending, fanning, and beading fibers. Our source code is available online.
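
The core idea of optimizing a latent code instead of raw vertex coordinates can be sketched at toy scale. Everything below is a hypothetical stand-in, not the paper's pipeline: a closed 2D curve instead of a 3D axon mesh, spectral coefficients taken (as is common in spectral mesh processing) as projections onto the lowest-frequency eigenvectors of a graph Laplacian, a fixed random linear matrix `Dec` in place of the trained auto-encoder's decoder, and a toy "signal" (mean squared extent of the shape along six directions) in place of the dMRI simulator. Gradients are chained by hand through signal, vertices, spectral coefficients and decoder; in a PyTorch implementation autograd would supply them.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, d = 64, 8, 4  # vertices, spectral modes kept, latent size (all hypothetical)

# Spectral basis: lowest-frequency eigenvectors of a cycle-graph Laplacian (closed 2D curve).
S = np.roll(np.eye(n), 1, axis=0)
Lap = 2.0 * np.eye(n) - S - S.T
_, Phi = np.linalg.eigh(Lap)
Phi = Phi[:, :k]  # (n, k), orthonormal columns

theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
X_mean = np.stack([np.cos(theta), np.sin(theta)], axis=1)  # template shape: unit circle
Dec = 0.3 * rng.standard_normal((2 * k, d))  # stand-in for a trained decoder (hypothetical)

angles = np.linspace(0.0, np.pi, 6, endpoint=False)
U = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (6, 2) probe directions


def decode(z):
    """Latent z -> spectral coefficients C (k, 2) -> vertex coordinates (n, 2)."""
    C = (Dec @ z).reshape(k, 2)
    return X_mean + Phi @ C


def toy_signal(X):
    """Toy stand-in for the simulator: mean squared extent of the shape along each direction."""
    return np.mean((X @ U.T) ** 2, axis=0)


def loss_and_grad(z, s_ref):
    """Squared signal mismatch and its gradient with respect to z (chain rule by hand)."""
    X = decode(z)
    P = X @ U.T                            # (n, 6) projections onto the directions
    r = np.mean(P ** 2, axis=0) - s_ref    # (6,) signal residual
    dX = (2.0 / n) * (P * (2.0 * r)) @ U   # dLoss/dX, shape (n, 2)
    dC = Phi.T @ dX                        # back through the spectral decoding
    dz = Dec.T @ dC.reshape(-1)            # back through the linear decoder
    return float(np.sum(r ** 2)), dz


z_ref = rng.standard_normal(d)
s_ref = toy_signal(decode(z_ref))          # synthetic "reference signal"
z = np.zeros(d)                            # start from the template shape
losses = []
for _ in range(500):
    loss, grad = loss_and_grad(z, s_ref)
    losses.append(loss)
    z = z - 0.05 * grad                    # gradient descent updates only the latent
```

The recovered `z` need not equal `z_ref`: several shapes can produce the same toy signal, a small-scale echo of the ill-posedness that motivates restricting the search to a low-dimensional latent space.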

Paper Structure

This paper contains 25 sections, 18 equations, 13 figures, 1 table, 2 algorithms.

Figures (13)

  • Figure 1: Overview of the ReMiDi-based dMRI reconstruction pipeline. A 3D mesh decoded by a Spectral Auto-Encoder (SAE) is given as input to the differentiable dMRI simulator. The loss between the simulated dMRI signal and a reference signal is backpropagated to the SAE latent, and the mesh is iteratively updated with gradient-based learning.
  • Figure 2: The signal generated by ReMiDi for a bent axon closely matches that of SpinDoctor's MATLAB implementation. Unlike SpinDoctor, ReMiDi is implemented in PyTorch and is fully differentiable.
  • Figure 3: Iterative reconstruction of a bent axon by ReMiDi at different gradient descent iterations, with the ground-truth mesh shown on the left. Top row: evolution of the triangulated mesh surface. Bottom row: corresponding point-cloud representations of the mesh vertices. The color scale shows Chamfer distance (lower is better).
  • Figure 4: dMRI signal loss over iterations for the reconstruction of a bent axon using ReMiDi.
  • Figure 5: Reconstruction error, measured as a modified Chamfer distance between reference meshes (middle) and reconstructed meshes (bottom), for increasing degrees of axon bending.
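
Figures 3 and 5 report reconstruction error as a (modified) Chamfer distance between vertex sets. The paper's specific modification is not described here, so the sketch below implements only the standard symmetric Chamfer distance, using squared nearest-neighbour distances (some works use unsquared distances instead). It also includes a per-vertex nearest-neighbour distance of the kind that could drive a color map like Figure 3's. Function names are illustrative, and the brute-force pairwise computation is O(nm), which is fine for small meshes; a k-d tree would be preferable for large ones.

```python
import numpy as np


def pairwise_sq_dists(A, B):
    """(n, m) matrix of squared Euclidean distances between rows of A (n, 3) and B (m, 3)."""
    diff = A[:, None, :] - B[None, :, :]
    return np.sum(diff ** 2, axis=-1)


def chamfer_distance(A, B):
    """Symmetric Chamfer distance: mean squared distance from each point to its nearest
    neighbour in the other cloud, summed over both directions."""
    d2 = pairwise_sq_dists(A, B)
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()


def per_vertex_error(A, B):
    """Distance from every vertex of A to its nearest vertex of B (e.g. for color-coding A)."""
    return np.sqrt(pairwise_sq_dists(A, B).min(axis=1))


# Usage with two tiny point clouds: the "reconstruction" is the reference shifted by 0.1 in y.
ref = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
rec = ref + np.array([0.0, 0.1, 0.0])
print(chamfer_distance(ref, rec))
print(per_vertex_error(rec, ref))
```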