Three-dimensional rigid-body transforms, i.e. rotations and translations, are central to modern differentiable machine learning pipelines in robotics, vision, and simulation. However, numerically robust and mathematically correct implementations, particularly on SO(3), are error-prone due to issues such as axis conventions, normalizations, composition consistency and subtle errors that only appear in edge cases. SciPy's spatial.transform module is a rigorously tested Python implementation. However, it historically only supported NumPy, limiting adoption in GPU-accelerated and autodiff-based workflows. We present a complete overhaul of SciPy's spatial.transform functionality that makes it compatible with any array library implementing the Python array API, including JAX, PyTorch, and CuPy. The revised implementation preserves the established SciPy interface while enabling GPU/TPU execution, JIT compilation, vectorized batching, and differentiation via native autodiff of the chosen backend. We demonstrate how this foundation supports differentiable scientific computing through two case studies: (i) scalability of 3D transforms and rotations and (ii) a JAX drone simulation that leverages SciPy's Rotation for accurate integration of rotational dynamics. Our contributions have been merged into SciPy main and will ship in the next release, providing a framework-agnostic, production-grade basis for 3D spatial math in differentiable systems and ML.
翻译:三维刚体变换(即旋转与平移)是现代机器人学、视觉与仿真领域中可微分机器学习流程的核心。然而,数值鲁棒且数学正确的实现(特别是在SO(3)上)极易出错,原因包括轴系约定、归一化处理、组合一致性以及仅在边缘情况下显现的细微误差等问题。SciPy的spatial.transform模块是一个经过严格测试的Python实现,但历史上仅支持NumPy,限制了其在GPU加速和基于自动微分的工作流中的应用。本文介绍了对SciPy spatial.transform功能的全面重构,使其兼容任何实现Python数组API的数组库(包括JAX、PyTorch和CuPy)。改进后的实现保留了SciPy的既定接口,同时支持通过所选后端的原生自动微分实现GPU/TPU执行、即时编译、向量化批处理及微分运算。我们通过两个案例研究展示了该基础如何支持可微分科学计算:(i)三维变换与旋转的可扩展性;(ii)利用SciPy Rotation进行旋转动力学精确积分的JAX无人机仿真。我们的贡献已并入SciPy主分支,将在下一版本中发布,为可微分系统与机器学习中的三维空间数学提供框架无关、生产级的基础支撑。