scipy.spatial.transform: Differentiable Framework-Agnostic 3D Transformations in Python

Martin Schuck, Alexander von Rohr, Angela P. Schoellig

TL;DR

Robust $SO(3)$/$SE(3)$ spatial transforms are essential for modern differentiable pipelines but are error-prone when reimplemented across frameworks. We present a framework-agnostic overhaul of SciPy's spatial.transform aligned with the Python array API, enabling backend-agnostic execution on CPU/GPU/TPU with native autodiff and JIT. Key contributions include a backend-selecting architecture, unit-quaternion rotations, homogeneous transformation representations, and efficient broadcasting across arbitrary batch dimensions, plus case studies on performance across frameworks and a differentiable drone simulator. The results demonstrate portability, numerical robustness, and production-grade capabilities, reducing framework-specific reimplementation and enabling differentiable spatial computation in ML pipelines.

Abstract

Three-dimensional rigid-body transforms, i.e. rotations and translations, are central to modern differentiable machine learning pipelines in robotics, vision, and simulation. However, numerically robust and mathematically correct implementations, particularly on SO(3), are error-prone due to issues such as axis conventions, normalizations, composition consistency and subtle errors that only appear in edge cases. SciPy's spatial.transform module is a rigorously tested Python implementation. However, it historically only supported NumPy, limiting adoption in GPU-accelerated and autodiff-based workflows. We present a complete overhaul of SciPy's spatial.transform functionality that makes it compatible with any array library implementing the Python array API, including JAX, PyTorch, and CuPy. The revised implementation preserves the established SciPy interface while enabling GPU/TPU execution, JIT compilation, vectorized batching, and differentiation via native autodiff of the chosen backend. We demonstrate how this foundation supports differentiable scientific computing through two case studies: (i) scalability of 3D transforms and rotations and (ii) a JAX drone simulation that leverages SciPy's Rotation for accurate integration of rotational dynamics. Our contributions have been merged into SciPy main and will ship in the next release, providing a framework-agnostic, production-grade basis for 3D spatial math in differentiable systems and ML.
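As a minimal sketch of the established interface the abstract refers to, the following example builds a batch of rotations, applies it to a vector, and composes two rotations. It uses NumPy inputs only, which work with existing SciPy releases. With the array API support described above, the same calls are meant to accept arrays from other libraries such as JAX or PyTorch. How that support is enabled in a given SciPy version is not stated here, so the example does not attempt it.

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

# A batch of two rotations about the z-axis, given as rotation vectors
# (axis scaled by angle in radians): 90 degrees and 180 degrees.
rotvecs = np.array([[0.0, 0.0, np.pi / 2],
                    [0.0, 0.0, np.pi]])
batch = R.from_rotvec(rotvecs)

# Apply both rotations to one vector; the vector is broadcast over the batch,
# so the result has shape (2, 3).
rotated = batch.apply(np.array([1.0, 0.0, 0.0]))

# Composition: `a * b` applies b first, then a.
quarter = R.from_euler("z", 90, degrees=True)
half = quarter * quarter

# Rotations are stored as unit quaternions; as_quat() returns them in
# scalar-last (x, y, z, w) order by default.
half_quat = half.as_quat()
```

Here `rotated[0]` is the x-axis turned by 90 degrees onto the y-axis, and `half` is equivalent to a single 180-degree turn about z.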

Paper Structure

This paper contains 6 sections, 1 equation, 3 figures.

Figures (3)

  • Figure 1: Overview of the changed architecture, using Rotation as an example. Instead of a pure Cython implementation that is only compatible with NumPy, the classes are now pure Python. Each method delegates computations to a backend that is selected depending on the array type. NumPy arrays are passed to the specialized Cython backend, whereas other array API frameworks use the new, generic backend. This enables advanced capabilities such as differentiation through transforms.
  • Figure 2: Computation time versus number of samples $N$ for the multiply and apply operations of Rotation and RigidTransform across array API backends. JAX timings are evaluated after JIT compilation. GPU backends incur a fixed overhead at small $N$ but achieve higher asymptotic throughput for large $N$. Notably, JAX often performs on par with or better than the custom Cython backend for NumPy.
  • Figure 3: Several drone trajectories optimized with fully differentiable dynamics and drone controllers leveraging scipy.spatial.transform.Rotation. Optimization progresses from left to right; reference positions are shown in green. The visualization is based on MuJoCo (Todorov et al., 2012).
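The backend selection described in the Figure 1 caption can be illustrated with a small sketch. This is not SciPy's internal code: the function names `rotate`, `rotate_numpy`, and `rotate_generic` are hypothetical, and the NumPy path here simply reuses the generic routine as a stand-in for a specialized compiled backend. The generic routine only uses operations from the array namespace `xp`, which is the property that lets one implementation serve several array libraries.

```python
import numpy as np


def _cross(xp, a, b):
    # Cross product along the last axis, written with basic indexing and
    # xp.stack so it does not depend on any library-specific helper.
    return xp.stack(
        [
            a[..., 1] * b[..., 2] - a[..., 2] * b[..., 1],
            a[..., 2] * b[..., 0] - a[..., 0] * b[..., 2],
            a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0],
        ],
        axis=-1,
    )


def rotate_generic(xp, quat, vec):
    """Rotate vec (..., 3) by unit quaternions quat (..., 4), scalar-last.

    Uses v' = v + w * t + u x t with t = 2 * (u x v), where u is the
    vector part and w the scalar part of the quaternion.
    """
    u = quat[..., :3]
    w = quat[..., 3:]
    t = 2.0 * _cross(xp, u, vec)
    return vec + w * t + _cross(xp, u, t)


def rotate_numpy(quat, vec):
    # Hypothetical stand-in for a specialized NumPy-only fast path.
    return rotate_generic(np, quat, vec)


def rotate(quat, vec):
    # Select the backend from the array type, as in the Figure 1 description.
    if isinstance(quat, np.ndarray):
        return rotate_numpy(quat, vec)
    # Array API compliant arrays expose their namespace via this method.
    xp = quat.__array_namespace__()
    return rotate_generic(xp, quat, vec)


# Two rotations about z (90 and 180 degrees) applied to the x-axis.
s = np.sin(np.pi / 4)
quats = np.array([[0.0, 0.0, s, s],
                  [0.0, 0.0, 1.0, 0.0]])
result = rotate(quats, np.array([1.0, 0.0, 0.0]))
```

Because `rotate_generic` is built from arithmetic, indexing, and `stack`, an autodiff-capable array library can trace through it without special handling, which is the kind of capability the caption attributes to the generic backend.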