ReMiDi: Reconstruction of Microstructure Using a Differentiable Diffusion MRI Simulator
Prathamesh Pradeep Khole, Zahra Kais Petiwala, Shri Prathaa Magesh, Ehsan Mirafzali, Utkarsh Gupta, Jing-Rebecca Li, Andrada Ianus, Razvan Marinescu
TL;DR
The paper addresses the reconstruction of neuronal microstructure from diffusion MRI (dMRI) signals. It introduces ReMiDi, a differentiable dMRI simulator implemented in PyTorch that backpropagates gradients through a forward model based on the Bloch-Torrey partial differential equation (PDE). ReMiDi combines three components. The first is a finite-element (FEM) discretization. The second is a fast matrix-formalism solver. The third is a latent-space representation, learned by a Spectral Auto-Encoder, that regularizes the inverse problem. Together they enable the reconstruction of arbitrary 3D meshes representing axonal geometry. Key contributions include differentiable, SpinDoctor-style physics, an efficient eigendecomposition-based forward solver, and end-to-end gradient-based mesh reconstruction. The reconstruction is demonstrated on bent, beaded, and fanning axon geometries and compared against a neural-network baseline. By enabling geometry-aware diffusion modeling, the work is a step toward in vivo mesoscopic brain mapping. Challenges remain in GPU memory usage, multi-compartment modeling, and application to real-world MRI data.
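The matrix-formalism idea can be illustrated on a toy problem whose Laplace eigenfunctions are known in closed form. The sketch below is a 1D analogue, assuming diffusion restricted to an interval $[0, \ell]$ with reflecting walls under a pulsed-gradient spin-echo (PGSE) sequence. ReMiDi itself works with the FEM eigenfunctions of a 3D mesh, which this sketch does not reproduce. The helper names (`neumann_basis_1d`, `expm_eig`, `pgse_signal_mf`) are hypothetical, and NumPy stands in for PyTorch. The magnetization is expanded in the first `n_modes` eigenfunctions. The Bloch-Torrey equation then reduces to the linear ODE $\dot{c} = -(D\Lambda + i q f(t) W)c$. On each interval where the gradient is constant, the solution is a matrix exponential, evaluated here through an eigendecomposition.

```python
import numpy as np


def neumann_basis_1d(length, n_modes):
    """Hypothetical helper: Laplace eigenvalues and position matrix W on [0, length]
    with reflecting (Neumann) walls, in the orthonormal cosine basis
    phi_0 = 1/sqrt(length), phi_n = sqrt(2/length) * cos(n*pi*x/length)."""
    n = np.arange(n_modes)
    lam = (n * np.pi / length) ** 2  # Laplace eigenvalues lambda_n

    def int_x_cos(k):
        # Closed form of the integral of u*cos(k*pi*u) over u in [0, 1]
        return 0.5 if k == 0 else ((-1) ** k - 1) / (k * np.pi) ** 2

    c = np.where(n == 0, 1.0, np.sqrt(2.0))  # normalisation constants
    W = np.empty((n_modes, n_modes))
    for m in range(n_modes):
        for k in range(n_modes):
            # cos(a) cos(b) = (cos(a - b) + cos(a + b)) / 2
            W[m, k] = c[m] * c[k] * 0.5 * (int_x_cos(abs(m - k)) + int_x_cos(m + k))
    return lam, length * W  # W[m, k] = integral of phi_m * x * phi_k over [0, length]


def expm_eig(A):
    """Matrix exponential via the eigendecomposition A = V diag(w) V^-1."""
    w, V = np.linalg.eig(A)
    return V @ np.diag(np.exp(w)) @ np.linalg.inv(V)


def pgse_signal_mf(D, length, q, delta, Delta, n_modes=40):
    """Normalised PGSE attenuation |S/S0| for diffusion restricted to [0, length].

    q = gamma * g; the effective gradient is +q on [0, delta],
    zero on [delta, Delta] and -q on [Delta, Delta + delta].
    """
    lam, W = neumann_basis_1d(length, n_modes)
    L = np.diag(lam)
    c = np.zeros(n_modes, dtype=complex)
    c[0] = 1.0  # uniform initial magnetisation lives entirely in phi_0
    # Bloch-Torrey equation in the eigenbasis: dc/dt = -(D*L + 1j*q*f(t)*W) c
    c = expm_eig(-delta * (D * L + 1j * q * W)) @ c  # first gradient pulse
    c = np.exp(-(Delta - delta) * D * lam) * c  # gradient off: L is diagonal
    c = expm_eig(-delta * (D * L - 1j * q * W)) @ c  # second, reversed pulse
    # Only phi_0 has a nonzero integral, so S/S0 = c_0(TE) / c_0(0)
    return float(abs(c[0]))
```

For example, use `D = 1`, `length = 1`, `delta = 1e-3`, `Delta = 10`, and `q * delta = π`. The result of `pgse_signal_mf` should then be close to the narrow-pulse, long-diffusion-time limit $(2/\pi)^2 \approx 0.405$ for a uniform interval. Every step is a matrix product or a matrix exponential, so the same computation written with PyTorch tensor operations could in principle be backpropagated through.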
Abstract
We propose ReMiDi, a novel method for inferring neuronal microstructure as arbitrary 3D meshes using a differentiable diffusion Magnetic Resonance Imaging (dMRI) simulator. We first implemented, in PyTorch, a differentiable dMRI simulator that simulates the forward diffusion process on an input 3D microstructure mesh using the finite-element method. To make simulations significantly faster, we solve the governing differential equation semi-analytically with a matrix formalism approach. Given a reference dMRI signal $S_{ref}$, we use the differentiable simulator to iteratively update the input mesh with gradient-based optimization until its simulated signal matches $S_{ref}$. Directly optimizing the 3D coordinates of the vertices is challenging, particularly because the inverse problem is ill-posed. We therefore optimize a lower-dimensional latent space representation of the mesh instead. The mesh is first encoded into spectral coefficients, which an auto-encoder further encodes into a latent vector $\textbf{z}$; decoding $\textbf{z}$ then yields the mesh again. The result is an end-to-end differentiable pipeline whose simulated signal can be tuned to match a reference signal by iteratively updating the latent representation $\textbf{z}$. We demonstrate the reconstruction of microstructures of arbitrary shape represented by finite-element meshes. We focus on axonal geometries found in brain white matter, including bending, fanning, and beading fibers. Our source code is available online.
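The abstract does not say which spectral basis produces the mesh's spectral coefficients. A common choice, assumed here, is to project the vertex coordinates onto the low-frequency eigenvectors of the mesh's graph Laplacian. For brevity, the sketch below uses a closed polygon, such as an axon cross-section contour, instead of a 3D mesh. The helpers `cycle_laplacian`, `spectral_encode`, and `spectral_decode` are hypothetical. The auto-encoder that would further compress the coefficients into $\textbf{z}$ is not shown.

```python
import numpy as np


def cycle_laplacian(n):
    """Graph Laplacian (degree minus adjacency) of a closed polygon with n vertices."""
    L = 2.0 * np.eye(n)
    idx = np.arange(n)
    L[idx, (idx + 1) % n] -= 1.0
    L[idx, (idx - 1) % n] -= 1.0
    return L


def spectral_encode(vertices, L, k):
    """Project (n x d) vertex coordinates onto the k lowest-frequency Laplacian eigenvectors."""
    _, U = np.linalg.eigh(L)  # eigenvalues in ascending order, orthonormal eigenvectors
    U_k = U[:, :k]
    return U_k.T @ vertices, U_k  # (k x d) spectral coefficients and the basis


def spectral_decode(coeffs, U_k):
    """Map spectral coefficients back to vertex coordinates."""
    return U_k @ coeffs


# Usage: an elliptical contour with 64 vertices compressed to 3 x 2 coefficients
t = np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False)
ellipse = np.stack([3.0 + 2.0 * np.cos(t), -1.0 + 0.5 * np.sin(t)], axis=1)
L = cycle_laplacian(64)
coeffs, U_k = spectral_encode(ellipse, L, k=3)
error = np.abs(spectral_decode(coeffs, U_k) - ellipse).max()
```

With `k=3`, the retained basis is the constant eigenvector plus the lowest-frequency cosine/sine pair. The ellipse's coordinates lie exactly in that span, so it is recovered up to rounding error. Sharper features, such as beads, need more coefficients. For a real 3D mesh, the Laplacian would presumably be built from the mesh connectivity or a cotangent Laplace-Beltrami operator, which is also an assumption here.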
