Flexible and Efficient Surrogate Gradient Modeling with Forward Gradient Injection
Sebastian Otte
TL;DR
The paper tackles the challenge of implementing surrogate gradients for non-differentiable operations in neural networks, particularly the Heaviside function in spiking neural networks (SNNs). It introduces Forward Gradient Injection (FGI), which encodes the shape of a surrogate gradient directly in the forward pass. FGI uses two mechanisms, gradient bypassing and gradient injection, both built on a stop-gradient operator. The author formalizes these ideas and demonstrates them in an SNN of adaptive leaky integrate-and-fire (ALIF) neurons on Sequential MNIST. The experiments show that FGI can outperform traditional custom-backward implementations, especially under TorchScript. A broader study with TorchScript and torch.compile also reveals potential training speedups of more than 7x and inference speedups of more than 16x over pure PyTorch. FGI thus offers a lightweight alternative to manual backward overrides in autograd-enabled frameworks, with clear benefits for rapid prototyping and runtime performance in SNNs and related models with non-differentiable components.
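The two mechanisms can be written compactly as follows. This is a sketch in assumed notation and not necessarily the paper's exact formulation. Here sg is a stop-gradient operator (identity in the forward pass, zero derivative), H is the Heaviside step, f is an arbitrary forward function, and g is a chosen surrogate derivative.

```latex
\begin{aligned}
&\text{Stop-gradient:} && \operatorname{sg}(z) = z, \qquad \frac{\partial \operatorname{sg}(z)}{\partial z} = 0 \\
&\text{Gradient bypassing:} && y = x - \operatorname{sg}(x) + \operatorname{sg}\big(f(x)\big)
  \;\Rightarrow\; y = f(x), \quad \frac{\partial y}{\partial x} = 1 \\
&\text{Gradient injection:} && h = x \cdot \operatorname{sg}\big(g(x)\big)
  \;\Rightarrow\; \frac{\partial h}{\partial x} = g(x) \\
&\text{Combined (Heaviside):} && y = h - \operatorname{sg}(h) + \operatorname{sg}\big(H(x)\big)
  \;\Rightarrow\; y = H(x), \quad \frac{\partial y}{\partial x} = g(x)
\end{aligned}
```

In the combined form, h - sg(h) is zero in value but carries the injected gradient g(x). Meanwhile, sg(H(x)) supplies the step output without contributing any gradient.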
Abstract
Automatic differentiation is a key feature of modern deep learning frameworks. These frameworks typically also provide ways to specify custom gradients within the computation graph. This is particularly important for defining surrogate gradients of non-differentiable operations such as the Heaviside function in spiking neural networks (SNNs). PyTorch, for example, lets users specify the backward pass of an operation by overriding the backward method of a custom torch.autograd.Function, and other frameworks provide comparable options. These methods are common practice and usually work well. However, they have several disadvantages: limited flexibility, additional source code overhead, poor usability, and a potentially strong negative impact on the effectiveness of automatic model optimization procedures. This paper presents an alternative way to formulate surrogate gradients: forward gradient injection (FGI). FGI applies a simple but effective combination of basic standard operations to inject an arbitrary gradient shape into the computation graph directly within the forward pass. It is demonstrated that using FGI is straightforward and convenient. Moreover, it is shown that, when TorchScript is used, FGI can significantly increase model performance in SNNs compared with custom backward methods. These results are complemented by a general performance study on recurrent SNNs with TorchScript and torch.compile. Compared with pure PyTorch, the study reveals the potential for a training speedup of more than 7x and an inference speedup of more than 16x.
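The following is a minimal sketch of FGI under stated assumptions. Rather than PyTorch, it uses a toy scalar reverse-mode autograd written from scratch so that it runs without dependencies. Every name in it (`Var`, `sg`, `heaviside`, `fast_sigmoid_grad`, `fgi_heaviside`) is hypothetical. The fast-sigmoid-style surrogate with `beta=10.0` is an arbitrary illustrative choice, not necessarily one used in the paper. The sketch shows that the FGI expression returns the Heaviside value in the forward pass while its gradient equals the surrogate derivative. The bare step function, by contrast, passes back no gradient at all.

```python
# Toy scalar reverse-mode autograd (NOT PyTorch), used only to illustrate FGI.
# Assumption: all names below are hypothetical and introduced for this sketch.


class Var:
    """Scalar node in a minimal reverse-mode autograd graph."""

    def __init__(self, value, parents=()):
        self.value = float(value)
        self.parents = parents  # tuple of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        other = _lift(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __sub__(self, other):
        other = _lift(other)
        return Var(self.value - other.value, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        other = _lift(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def backward(self):
        # Order nodes so each node comes after its inputs, then propagate
        # gradients from the output back toward the leaves.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad


def _lift(x):
    # Wrap plain numbers as constant leaves.
    return x if isinstance(x, Var) else Var(x)


def sg(x):
    """Stop-gradient: keeps the value but drops all parents (zero gradient)."""
    return Var(x.value)


def heaviside(x):
    """Step function; as a non-differentiable op it gets no gradient path."""
    return Var(1.0 if x.value >= 0.0 else 0.0)


def fast_sigmoid_grad(v, beta=10.0):
    """Example surrogate derivative g(v) = 1 / (1 + beta*|v|)^2 (our choice)."""
    return 1.0 / (1.0 + beta * abs(v)) ** 2


def fgi_heaviside(x, surrogate=fast_sigmoid_grad):
    # Gradient injection: x * sg(g(x)) has derivative g(x) w.r.t. x.
    # (Var(...) is already constant here; sg mirrors a real framework, where
    # g(x) would be computed from x and must be detached.)
    h = x * sg(Var(surrogate(x.value)))
    # Gradient bypassing: h - sg(h) is 0 in value but keeps h's gradient,
    # while sg(H(x)) supplies the forward value without any gradient.
    return h - sg(h) + sg(heaviside(x))


if __name__ == "__main__":
    for v in (0.5, -0.2):
        x_plain, x_fgi = Var(v), Var(v)
        y_plain, y_fgi = heaviside(x_plain), fgi_heaviside(x_fgi)
        y_plain.backward()
        y_fgi.backward()
        print(v, y_plain.value, x_plain.grad, y_fgi.value, x_fgi.grad)
```

In PyTorch, `Tensor.detach()` plays the role of `sg`. The last two lines of `fgi_heaviside` would therefore carry over to tensor code largely unchanged, with `torch.heaviside` or a comparison such as `(x >= 0).float()` supplying the step. Because the surrogate lives in the forward expression, changing its shape means passing a different function rather than writing a new backward method.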
