Measuring and regularizing networks in function space
Ari S. Benjamin, David Rolnick, Konrad Kording
TL;DR
This work argues that the input/output function of a neural network should be analyzed in function space using the $L^2$ metric, rather than relying solely on parameter-space distances. It shows that function-space trajectories during training can differ markedly from parameter-space paths, and that the relation between the two changes as optimization proceeds. The authors introduce two practical approaches: a working-memory-based function-space regularizer that mitigates catastrophic forgetting, and Hilbert-constrained gradient descent (HCGD), which bounds the change in function space made by each update and which they relate to the natural gradient. Their experiments show context-dependent benefits, with potential gains in continual learning and in certain regimes of recurrent tasks. Overall, the paper provides a framework for measuring and regularizing function distances directly, complementing traditional parameter-focused optimization.
Abstract
To optimize a neural network one often thinks of optimizing its parameters, but it is ultimately a matter of optimizing the function that maps inputs to outputs. Since a change in the parameters might serve as a poor proxy for the change in the function, it is of some concern that primacy is given to parameters while the correspondence between the two has not been tested. Here, we show that it is simple and computationally feasible to calculate distances between functions in an $L^2$ Hilbert space. We examine how typical networks behave in this space, and compare parameter $\ell^2$ distances with function $L^2$ distances between various points of an optimization trajectory. We find that the two distances are nontrivially related. In particular, the $L^2/\ell^2$ ratio decreases throughout optimization, reaching a steady value around the time test error plateaus. We then investigate how the $L^2$ distance could be applied directly to optimization. First, we propose that in multitask learning, one can avoid catastrophic forgetting by directly limiting how much the input/output function changes between tasks. Second, we propose a new learning rule that constrains the distance a network can travel through $L^2$-space in any one update. This allows new examples to be learned in a way that minimally interferes with what has previously been learned. These applications demonstrate how one can measure and regularize function distances directly, without relying on parameters or local approximations like loss curvature.
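As a rough illustration of the two quantities compared above, the following sketch estimates the $L^2$ distance between two functions by averaging squared output differences over a sample of inputs, and computes the $\ell^2$ distance between their parameter vectors. The sampling-based estimate, the toy affine model, and all names (`function_l2_distance`, `parameter_l2_distance`, `make_affine`) are assumptions made for this example, not the authors' code.

```python
import math
import random

def function_l2_distance(f, g, inputs):
    """Empirical L2 distance: root mean squared output difference over `inputs`."""
    sq_norms = [sum((fa - ga) ** 2 for fa, ga in zip(f(x), g(x))) for x in inputs]
    return math.sqrt(sum(sq_norms) / len(sq_norms))

def parameter_l2_distance(params_a, params_b):
    """Euclidean (l2) distance between two flat parameter vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(params_a, params_b)))

def make_affine(params):
    """Hypothetical toy model: x -> [w1 * x[0] + w2 * x[1] + b]."""
    w1, w2, b = params
    return lambda x: [w1 * x[0] + w2 * x[1] + b]

random.seed(0)
# Assumed input distribution: inputs concentrated near the origin.
inputs = [(random.gauss(0, 0.1), random.gauss(0, 0.1)) for _ in range(1000)]

base = [0.5, -1.0, 0.0]
bias_shift = [0.5, -1.0, 2.0]    # only the bias moves, by 2
weight_shift = [2.5, -1.0, 0.0]  # only the first weight moves, by 2

f = make_affine(base)
d_param_bias = parameter_l2_distance(base, bias_shift)
d_param_weight = parameter_l2_distance(base, weight_shift)
d_func_bias = function_l2_distance(f, make_affine(bias_shift), inputs)
d_func_weight = function_l2_distance(f, make_affine(weight_shift), inputs)

print("parameter distances:", d_param_bias, d_param_weight)
print("function distances: ", d_func_bias, d_func_weight)
print("L2/l2 ratios:       ", d_func_bias / d_param_bias, d_func_weight / d_param_weight)
```

Both perturbations lie at the same parameter distance from the base model, yet the bias shift changes every output by the same amount while the weight shift changes outputs only in proportion to the first input coordinate. With inputs concentrated near the origin, the two function distances therefore differ substantially, which is the sense in which a parameter distance can be a poor proxy for a change in the function.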
