Function-Space Learning Rates

Edward Milsom, Ben Anson, Laurence Aitchison

TL;DR

This work introduces layerwise function-space learning rates to quantify how parameter updates affect network outputs, not just parameter values, and provides an efficient Monte-Carlo estimator that requires only a single additional backward pass per estimate. A Kronecker-factorized covariance model further reduces variance, enabling scalable estimation in arbitrary architectures. The authors then propose FLeRM, which records base function-space updates from a small model and adjusts parameter-space learning rates in larger models to match those function-space changes, allowing hyperparameter transfer across width, depth, initialization scale, and LoRA rank. Empirical results on ResMLP and Transformer variants reveal time-varying, layer-structured function-space dynamics and demonstrate that FLeRM substantially improves hyperparameter transfer and training stability across scaling scenarios, offering a practical approach to tuning very large models.
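The idea of estimating a function-space change with one backward pass per sample can be illustrated with a standard random-projection identity: for a Gaussian vector $z$ in output space, $\mathbb{E}[\langle z, \Delta F\rangle^2] = \lVert \Delta F \rVert^2$, and to first order $\langle z, \Delta F\rangle$ equals the inner product of the parameter update with the gradient of $\langle z, F\rangle$. The sketch below is a minimal illustration under assumptions, not the authors' estimator: it uses a single linear layer `F = W @ X` (so the first-order approximation is exact), writes the backward pass out by hand, and all function names are hypothetical.

```python
import math
import random


def matmul(A, B):
    """Plain-Python matrix product of two lists of rows."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def exact_output_change_rms(dW, X):
    """RMS entry of dF = dW @ X, the exact output change of a linear layer."""
    dF = matmul(dW, X)
    count = len(dF) * len(dF[0])
    return math.sqrt(sum(v * v for row in dF for v in row) / count)


def mc_output_change_rms(dW, X, num_samples, rng):
    """Monte-Carlo estimate of the same quantity.

    Each sample draws a Gaussian z in output space and performs one
    (hand-written) backward pass: the gradient of <z, W @ X> with respect
    to W is z @ X^T. Its inner product with dW equals <z, dF>, whose
    square has expectation ||dF||^2.
    """
    out_dim, batch = len(dW), len(X[0])
    X_T = [list(col) for col in zip(*X)]
    total = 0.0
    for _ in range(num_samples):
        z = [[rng.gauss(0.0, 1.0) for _ in range(batch)] for _ in range(out_dim)]
        grad = matmul(z, X_T)  # shape: out_dim x in_dim, same as W
        proj = sum(g * d for g_row, d_row in zip(grad, dW)
                   for g, d in zip(g_row, d_row))
        total += proj * proj
    return math.sqrt(total / num_samples / (out_dim * batch))


# Small demonstration with made-up shapes: 3 outputs, 5 inputs, batch of 4.
demo_rng = random.Random(0)
demo_X = [[demo_rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(5)]
demo_dW = [[0.01 * demo_rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(3)]
print("exact:", exact_output_change_rms(demo_dW, demo_X))
print("estimate:", mc_output_change_rms(demo_dW, demo_X, 2000, demo_rng))
```

In a real network the hand-written gradient would be replaced by an automatic-differentiation backward pass, and the estimate would be computed separately for each parameter tensor.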

Abstract

We consider layerwise function-space learning rates, which measure the magnitude of the change in a neural network's output function in response to an update to a parameter tensor. This contrasts with traditional learning rates, which describe the magnitude of changes in parameter space. We develop efficient methods to measure and set function-space learning rates in arbitrary neural networks, requiring only minimal computational overhead through a few additional backward passes that can be performed at the start of, or periodically during, training. We demonstrate two key applications: (1) analysing the dynamics of standard neural network optimisers in function space, rather than parameter space, and (2) introducing FLeRM (Function-space Learning Rate Matching), a novel approach to hyperparameter transfer across model scales. FLeRM records function-space learning rates while training a small, cheap base model, then automatically adjusts parameter-space layerwise learning rates when training larger models to maintain consistent function-space updates. FLeRM gives hyperparameter transfer across model width, depth, initialisation scale, and LoRA rank in various architectures including MLPs with residual connections and transformers with different layer normalisation schemes.
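The matching step described above can be sketched as a per-layer rescaling. The following is a minimal sketch, assuming that a layer's function-space change scales roughly linearly with its parameter-space learning rate, so that multiplying the learning rate by the ratio of the base model's recorded value to the larger model's measured value brings the two into line. The function name, the dictionary layout, and the layer names are hypothetical, and the authors' actual update rule may differ.

```python
def match_learning_rates(param_lrs, measured_fs_lrs, base_fs_lrs):
    """Rescale per-layer parameter-space learning rates.

    param_lrs:       current parameter-space learning rate per layer (large model)
    measured_fs_lrs: function-space learning rate measured in the large model
    base_fs_lrs:     function-space learning rate recorded from the base model

    Assumes (for illustration only) that function-space change is
    proportional to the parameter-space learning rate.
    """
    new_lrs = {}
    for name, lr in param_lrs.items():
        measured = measured_fs_lrs[name]
        if measured <= 0.0:
            # Nothing measurable for this layer: leave its rate unchanged.
            new_lrs[name] = lr
        else:
            new_lrs[name] = lr * base_fs_lrs[name] / measured
    return new_lrs


# Hypothetical numbers: the wide model's hidden layer moves the output
# twice as much as the base model's did, so its learning rate is halved.
adjusted = match_learning_rates(
    param_lrs={"hidden": 0.01, "readout": 0.01},
    measured_fs_lrs={"hidden": 0.2, "readout": 0.05},
    base_fs_lrs={"hidden": 0.1, "readout": 0.05},
)
print(adjusted)
```

Because the abstract notes that the extra backward passes can be run at the start of training or periodically, a rescaling like this could be applied once or repeated at intervals.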

Paper Structure

This paper contains 33 sections, 35 equations, 26 figures, 1 algorithm.

Figures (26)

  • Figure 1: Function-space learning rates over time, measured using our approach, for the ResMLP model (top) and the Transformer (PostNorm) model (bottom). "QK Weights" refers to $W_\text{Q}$ or $W_\text{K}$ (query and key weight matrices), whilst "VO Weights" refers to $W_\text{V}$ or $W_\text{O}$ (values and head-concatenation projection weight matrices).
  • Figure 2: FLeRM dramatically improves optimal learning rate transfer across widths. Top: standard practice. Bottom: FLeRM.
  • Figure 3: FLeRM improves or maintains optimal learning rate transfer across depth. Top: standard practice. Bottom: FLeRM.
  • Figure 4: FLeRM allows us to train initialisation scale invariant networks. Top: standard practice. Bottom: FLeRM.
  • Figure 5: FLeRM improves optimal learning rate transfer when changing LoRA rank. The plots show how training loss behaves as the learning rate of $B$ and the LoRA rank are varied, for two continual pretraining tasks. Top: standard AdamW optimiser. Bottom: FLeRM. Some lines end abruptly at larger learning rates, indicating numerical instability.
  • ...and 21 more figures