Measuring and regularizing networks in function space

Ari S. Benjamin, David Rolnick, Konrad Kording

TL;DR

This work argues that the input/output function of a neural network should be analyzed in function space using the $L^2$ metric, rather than relying solely on parameter-space distances. It shows that function-space trajectories during training can differ markedly from parameter-space paths and that the relation between the two changes as optimization proceeds. The authors introduce two practical approaches: a function-space regularizer, computed over a working-memory cache, that mitigates catastrophic forgetting, and Hilbert-constrained gradient descent (HCGD), which bounds the per-update change in function space and which they relate to the natural gradient. Their experiments show context-dependent benefits, with potential gains in continual learning and in certain regimes of recurrent tasks. Overall, the paper provides a framework for measuring and regularizing function distances directly, as a complement to traditional parameter-focused optimization.

Abstract

To optimize a neural network one often thinks of optimizing its parameters, but it is ultimately a matter of optimizing the function that maps inputs to outputs. Since a change in the parameters might serve as a poor proxy for the change in the function, it is of some concern that primacy is given to parameters but that the correspondence has not been tested. Here, we show that it is simple and computationally feasible to calculate distances between functions in an $L^2$ Hilbert space. We examine how typical networks behave in this space, and compare parameter $\ell^2$ distances with function $L^2$ distances between various points of an optimization trajectory. We find that the two distances are nontrivially related. In particular, the $L^2/\ell^2$ ratio decreases throughout optimization, reaching a steady value around when test error plateaus. We then investigate how the $L^2$ distance could be applied directly to optimization. First, we propose that in multitask learning, one can avoid catastrophic forgetting by directly limiting how much the input/output function changes between tasks. Second, we propose a new learning rule that constrains the distance a network can travel through $L^2$-space in any one update. This allows new examples to be learned in a way that minimally interferes with what has previously been learned. These applications demonstrate how one can measure and regularize function distances directly, without relying on parameters or local approximations like loss curvature.
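The two distances contrasted in the abstract can be illustrated with a minimal sketch. It assumes the function $L^2$ distance is estimated by averaging the squared output difference over a batch of held-out inputs and taking the square root. The toy network `two_layer_relu`, the helper names, and the random inputs are hypothetical and only for illustration; in practice the outputs would come from the trained networks (for example, after the softmax layer). Rescaling the two layers of a bias-free ReLU network by 2 and 1/2 changes the parameters but not the function, which shows why parameter distance can be a poor proxy for function distance.

```python
import numpy as np


def function_l2_distance(out_f, out_g):
    """Monte Carlo estimate of the L^2 distance between two functions.

    out_f, out_g: arrays of shape (N, K) holding the outputs of the two
    networks on the same N inputs (assumed drawn from the data distribution).
    """
    out_f = np.asarray(out_f, dtype=float)
    out_g = np.asarray(out_g, dtype=float)
    per_example_sq = np.sum((out_f - out_g) ** 2, axis=1)
    return float(np.sqrt(per_example_sq.mean()))


def parameter_l2_distance(params_a, params_b):
    """Euclidean (l^2) distance between two lists of parameter arrays."""
    total = sum(
        np.sum((np.asarray(a, dtype=float) - np.asarray(b, dtype=float)) ** 2)
        for a, b in zip(params_a, params_b)
    )
    return float(np.sqrt(total))


def two_layer_relu(params, x):
    """Hypothetical toy network: x -> relu(x @ w1) @ w2, without biases."""
    w1, w2 = params
    return np.maximum(x @ w1, 0.0) @ w2


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8))  # stand-in for held-out inputs
params_a = [rng.normal(size=(8, 16)), rng.normal(size=(16, 4))]
# Rescale the layers by 2 and 1/2: different parameters, identical function.
params_b = [2.0 * params_a[0], 0.5 * params_a[1]]

fn_dist = function_l2_distance(two_layer_relu(params_a, x),
                               two_layer_relu(params_b, x))
par_dist = parameter_l2_distance(params_a, params_b)
print(f"function L2 distance:  {fn_dist:.3e}")
print(f"parameter l2 distance: {par_dist:.3f}")
```

Here the function distance is numerically zero while the parameter distance is large. The opposite case, a small parameter step that produces a large change in outputs, is equally possible.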

Paper Structure

This paper contains 18 sections, 15 equations, 13 figures, 3 algorithms.

Figures (13)

  • Figure 1: Visualization of the trajectories of three random initializations of a network through function space (left) and parameter space (right). The network is a convolutional network trained on a 5,000-image subset of CIFAR-10. At each epoch, we compute the $L^2$ and $\ell^2$ distances between all previous epochs, forming two distance matrices, and then recompute the 2D embedding from these matrices using multidimensional scaling. Each point on the plots represents the network at a new epoch of training. The black arrows represent the direction of movement.
  • Figure 2: Parameter distance is sometimes, but not always, representative of function distance. Here we compare the two at three scales during the optimization of a CNN on CIFAR-10. Left: Distances between the individual SGD updates. Middle: Distances between each epoch. Right: Distances from initialization. On all three plots, note the changing relationship between function and parameter distances throughout optimization. The network is the same as in Figure 1: a CNN with four convolutional layers with batch normalization, followed by two fully-connected layers, trained with SGD with learning rate = 0.1, momentum = 0.9, and weight decay = 1e-4. Note that the $L^2$ distance is computed from the output after the softmax layer, meaning possible values range from 0 to 1.
  • Figure 3: The variance of the $L^2$ estimator is small enough that the distance can be reasonably estimated from a few hundred examples. In panels A and D, we reproduce $L^2$ distances seen in the panels of Fig. 2. As we increase the number of validation examples these distances are computed over, the estimates become more accurate. Panels B and E show the 95% confidence bounds for the estimate; on 95% of batches, the value will lie between these bounds. These bounds can be obtained from the standard deviation of the $L^2$ distance on single examples. In panel C we show that the standard deviation scales linearly with the $L^2$ distance when measured between updates, meaning that a fixed batch size will often give similar percentage errors. This is not true for the distance from initialization, in panel F; early optimization has higher variance relative to magnitude, meaning that more examples are needed for the same uncertainty. In the Appendix, we also display the convergence of the $L^2$ distance estimator between epochs.
  • Figure 4: Regularizing the $L^2$ distance from old tasks (calculated over a working memory cache of size 1024) can successfully prevent catastrophic forgetting. Here we display the test performance on the first task as 7 subsequent tasks are learned. Our method outperforms simply retraining on the same cache (ADAM+retrain), which potentially overfits to the cache. Also displayed are ADAM without modifications, EWC, and SI.
  • Figure 5: Results of a SqueezeNet v1.1 trained on CIFAR-10. The learning rate $\epsilon$ is decreased by a factor of 10 at epoch 150. For the train error we overlay the running average of each trace for clarity.
  • ...and 8 more figures
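
The embedding step described in the Figure 1 caption (pairwise distance matrix to 2D coordinates via multidimensional scaling) can be sketched as follows. The caption does not say which MDS variant was used, so this sketch assumes classical (Torgerson) MDS, implemented directly in NumPy. The function names and the random test points are hypothetical; in the paper's setting, `dist` would be the matrix of $L^2$ or $\ell^2$ distances between network checkpoints.

```python
import numpy as np


def classical_mds(dist, n_components=2):
    """Embed points in n_components dimensions from a pairwise distance matrix.

    Classical (Torgerson) MDS: double-center the squared distances to obtain
    a Gram matrix, then keep its top eigenvectors scaled by sqrt(eigenvalue).
    """
    dist = np.asarray(dist, dtype=float)
    n = dist.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    gram = -0.5 * centering @ (dist ** 2) @ centering
    eigvals, eigvecs = np.linalg.eigh(gram)  # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:n_components]
    # Clip small negative eigenvalues that arise from non-Euclidean distances
    # or floating-point error.
    top_vals = np.clip(eigvals[order], 0.0, None)
    return eigvecs[:, order] * np.sqrt(top_vals)


def pairwise_distances(points):
    """Euclidean distance matrix for an array of shape (n, d)."""
    diff = points[:, None, :] - points[None, :, :]
    return np.linalg.norm(diff, axis=-1)


# Sanity check on synthetic 2D points standing in for network checkpoints:
# the embedding should reproduce the original pairwise distances.
rng = np.random.default_rng(0)
points = rng.normal(size=(6, 2))
dist = pairwise_distances(points)
emb = classical_mds(dist, n_components=2)
print(np.abs(pairwise_distances(emb) - dist).max())
```

The recovered coordinates are determined only up to rotation, reflection, and translation, so the axes of such a plot carry no meaning of their own. When the distances cannot be realized exactly in two dimensions, as is likely for real training trajectories, the embedding is an approximation.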