Function-Space Learning Rates
Edward Milsom, Ben Anson, Laurence Aitchison
TL;DR
This work introduces layerwise function-space learning rates to quantify how parameter updates affect network outputs, not just parameter values, and provides an efficient Monte-Carlo estimator that requires only a single additional backward pass per estimate. A Kronecker-factorized covariance model further reduces variance, enabling scalable estimation in arbitrary architectures. The authors then propose FLeRM, which records base function-space updates from a small model and adjusts parameter-space learning rates in larger models to match those function-space changes, allowing hyperparameter transfer across width, depth, initialization scale, and LoRA rank. Empirical results on ResMLP and Transformer variants reveal time-varying, layer-structured function-space dynamics and demonstrate that FLeRM substantially improves hyperparameter transfer and training stability across scaling scenarios, offering a practical approach to tuning very large models.
Abstract
We consider layerwise function-space learning rates, which measure the magnitude of the change in a neural network's output function in response to an update to a parameter tensor. This contrasts with traditional learning rates, which describe the magnitude of changes in parameter space. We develop efficient methods to measure and set function-space learning rates in arbitrary neural networks. These methods add only minimal computational overhead: a few additional backward passes, which can be performed at the start of training or periodically during it. We demonstrate two key applications: (1) analysing the dynamics of standard neural network optimisers in function space, rather than parameter space, and (2) introducing FLeRM (Function-space Learning Rate Matching), a novel approach to hyperparameter transfer across model scales. FLeRM records function-space learning rates while training a small, cheap base model, then automatically adjusts the layerwise parameter-space learning rates when training larger models so that the function-space updates stay consistent with those of the base model. FLeRM enables hyperparameter transfer across model width, depth, initialisation scale, and LoRA rank in various architectures, including MLPs with residual connections and transformers with different layer normalisation schemes.
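
The idea of matching function-space updates can be illustrated with a small sketch. It is not the estimator described above: rather than a Monte-Carlo estimate from additional backward passes, it measures the output change by brute force, with one extra forward pass per layer, on a toy two-layer MLP. The names `mlp_forward`, `function_space_change`, and `matched_lr_scales` are hypothetical, and the use of the RMS output change as the "function-space learning rate" is an assumption made for illustration.

```python
import numpy as np

def mlp_forward(params, x):
    """Toy two-layer MLP (assumed architecture); params is [W0, W1]."""
    h = np.tanh(x @ params[0])
    return h @ params[1]

def function_space_change(params, updates, layer, x):
    """RMS change in the outputs on batch x when only `layer` is updated."""
    perturbed = [p.copy() for p in params]
    perturbed[layer] = perturbed[layer] + updates[layer]
    diff = mlp_forward(perturbed, x) - mlp_forward(params, x)
    return float(np.sqrt(np.mean(diff ** 2)))

def matched_lr_scales(params, updates, x, base_changes):
    """Per-layer multipliers for the updates (equivalently, the learning rates).

    To first order in the update size, scaling layer l's update by the
    returned factor brings its function-space change to base_changes[l],
    the value recorded on a (hypothetical) small base model.
    """
    scales = []
    for layer, target in enumerate(base_changes):
        current = function_space_change(params, updates, layer, x)
        scales.append(target / current)
    return scales

# Usage: a wider model whose layerwise updates are rescaled to hit
# made-up base values (these numbers are placeholders, not results).
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 8))
params = [rng.normal(size=(8, 64)) / np.sqrt(8),
          rng.normal(size=(64, 4)) / np.sqrt(64)]
updates = [1e-3 * rng.normal(size=p.shape) for p in params]
base_changes = [1e-3, 2e-3]
scales = matched_lr_scales(params, updates, x, base_changes)
```

The rescaling is exact for the final linear layer and only approximate for the first layer, because the `tanh` nonlinearity makes the output change nonlinear in the update; for small updates the discrepancy is small.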
