Moonwalk: Inverse-Forward Differentiation
Dmitrii Krylov, Armin Karamzade, Roy Fox
TL;DR
Moonwalk tackles the memory bottleneck of backpropagation in invertible networks by computing true gradients with forward-mode differentiation. It works in two phases: it first computes the gradient of the loss with respect to the input, $h_0$, and then uses a vector-inverse-Jacobian product to propagate that gradient forward and obtain layer-wise gradients. Two variants, Pure-Forward and Mixed-Mode, offer different time-memory trade-offs. Theoretical analysis shows that Moonwalk needs far less memory than backpropagation and runs in time linear in network depth, versus quadratic for naïve forward-mode differentiation; the Mixed-Mode variant, which adds reverse-mode differentiation, approaches the speed of backpropagation. Empirical results on a CIFAR-10 RevNet demonstrate substantial memory savings and large speedups (e.g., up to 27× faster for 6 layers and 110× faster for 60 layers) while maintaining gradient fidelity and numerical stability. Overall, Moonwalk provides a practical, scalable path to exact gradient computation in invertible networks with a far lower memory footprint than backpropagation.
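The two-phase idea can be written as a pair of recurrences. The notation here is an assumption made for illustration and is not taken from the summary: a network $x_{i+1} = f_i(x_i; \theta_i)$ for $i = 0, \dots, n-1$, a loss $L(x_n)$, row-vector gradients $h_i = \partial L / \partial x_i$, and layer Jacobians $J_i = \partial x_{i+1} / \partial x_i$, each assumed invertible.

$$
\begin{aligned}
\text{Backpropagation (reverse order):}\quad & h_i = h_{i+1} J_i, \\
\text{Vector-inverse-Jacobian product (forward order):}\quad & h_{i+1} = h_i J_i^{-1}, \\
\text{Layer-wise parameter gradient:}\quad & \frac{\partial L}{\partial \theta_i} = h_{i+1} \frac{\partial x_{i+1}}{\partial \theta_i}.
\end{aligned}
$$

The second line is the first one solved for $h_{i+1}$. Once $h_0$ is known, every later $h_i$ follows in the same order as the forward pass, so activations do not have to be stored for a backward sweep.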
Abstract
Backpropagation, while effective for gradient computation, has a high memory cost that limits scalability. This work explores forward-mode gradient computation as an alternative in invertible networks, showing its potential to reduce the memory footprint without substantial drawbacks. We introduce a novel technique based on a vector-inverse-Jacobian product that accelerates the computation of forward gradients while retaining the advantages of memory reduction and preserving the fidelity of true gradients. Our method, Moonwalk, has a time complexity linear in the depth of the network, unlike the quadratic time complexity of naïve forward-mode differentiation, and empirically reduces computation time by several orders of magnitude without allocating more memory. We further accelerate Moonwalk by combining it with reverse-mode differentiation to achieve time complexity comparable with backpropagation while maintaining a much smaller memory footprint. Finally, we showcase the robustness of our method across several architecture choices. Moonwalk is the first forward-based method to compute true gradients in invertible networks in computation time comparable to backpropagation while using significantly less memory.
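To make the linear-time claim concrete, here is a minimal sketch in plain Python. It is a toy under stated assumptions, not the authors' implementation: the layers are scalar maps `x + w * tanh(x)` (a hypothetical choice, invertible for `w > -1`), the loss is `0.5 * x_n**2`, and all function names are made up for illustration. With scalars, the vector-inverse-Jacobian product reduces to a division; with vectors it would become a linear solve against the layer Jacobian.

```python
import math

def layer(x, w):
    """Toy invertible scalar layer (strictly increasing in x when w > -1)."""
    return x + w * math.tanh(x)

def jac_x(x, w):
    """d layer / d x, nonzero for w > -1."""
    return 1.0 + w * (1.0 - math.tanh(x) ** 2)

def jac_w(x, w):
    """d layer / d w."""
    return math.tanh(x)

def loss(x0, ws):
    x = x0
    for w in ws:
        x = layer(x, w)
    return 0.5 * x * x

def moonwalk_grads(x0, ws):
    """Two forward passes, no stored activations."""
    # Phase 1: forward-mode pass carrying the tangent dx_i/dx_0.
    x, t = x0, 1.0
    for w in ws:
        t = jac_x(x, w) * t      # tangent uses the layer's input x_i
        x = layer(x, w)
    h = x * t                    # h_0 = (dL/dx_n) * (dx_n/dx_0), dL/dx_n = x_n
    # Phase 2: push h forward with the inverse Jacobian.
    x, grads = x0, []
    for w in ws:
        h = h / jac_x(x, w)      # h_{i+1} = h_i * J_i^{-1}
        grads.append(h * jac_w(x, w))
        x = layer(x, w)
    return grads

def backprop_grads(x0, ws):
    """Reference reverse-mode gradients; stores every activation."""
    xs = [x0]
    for w in ws:
        xs.append(layer(xs[-1], w))
    h = xs[-1]                   # dL/dx_n
    grads = [0.0] * len(ws)
    for i in reversed(range(len(ws))):
        grads[i] = h * jac_w(xs[i], ws[i])
        h = h * jac_x(xs[i], ws[i])
    return grads
```

Calling `moonwalk_grads(0.7, [0.5, -0.3, 0.8, 0.2])` and `backprop_grads` with the same arguments should agree up to floating-point rounding. Both run in time linear in the number of layers, but the Moonwalk sketch keeps only the current `x` and `h` in memory, whereas the reference keeps the full list `xs`.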
