Scalable Optimization in the Modular Norm
Tim Large, Yang Liu, Minyoung Huh, Hyojin Bahng, Phillip Isola, Jeremy Bernstein
TL;DR
This paper introduces the modular norm, an architecture-aware norm on the full weight space of a neural network. It is defined recursively alongside the architecture itself, with the goal of keeping training well behaved as width and depth grow. By normalizing weight updates in this norm, the authors obtain learning rates that transfer across width and depth. They also prove that, for networks built from well-behaved atomic modules, the gradient is Lipschitz-continuous in the modular norm, with a Lipschitz constant given by a simple recursive formula. They implement Modula, a Python package that constructs modular norms and normalizes the updates of base optimizers, and demonstrate improved scale-up performance on GPT-like models and vision architectures. The work bridges practical optimization and theory by deriving sharpness and smoothness results in the modular-norm setting and by offering a library that applies these ideas to real models. The approach promises more stable, scalable training and invites further exploration of normed optimization and mass-allocation strategies.
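To make "defined recursively" concrete, here is a minimal Python sketch of a norm assembled in tandem with a tree of modules. Atom, Compound, l2, and the combination rule are hypothetical simplifications for illustration. They are not Modula's API or the paper's exact definition, which may involve further per-module quantities. In this sketch, each leaf module carries a norm on its own weights and a positive mass. A composite module takes the maximum over its children of each child's norm scaled by (total mass / child mass).

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence, Union


@dataclass
class Atom:
    """Hypothetical leaf module: a norm on its own weights plus a positive mass."""
    norm: Callable[[Sequence[float]], float]
    mass: float

    def modular_norm(self, weights: Sequence[float]) -> float:
        return self.norm(weights)


@dataclass
class Compound:
    """Hypothetical composite module whose norm is built from its children's norms."""
    children: List[Union["Atom", "Compound"]]

    @property
    def mass(self) -> float:
        # Assumption: a composite's mass is the sum of its children's masses.
        return sum(child.mass for child in self.children)

    def modular_norm(self, weights) -> float:
        # `weights` holds one entry per child, in the same order as self.children.
        # Simplified rule (assumption): max over children of
        # (total mass / child mass) * child norm.
        total = self.mass
        return max(
            (total / child.mass) * child.modular_norm(w)
            for child, w in zip(self.children, weights)
        )


def l2(w: Sequence[float]) -> float:
    """Euclidean norm of a flat weight vector (stand-in for a layer's own norm)."""
    return math.sqrt(sum(x * x for x in w))


# Usage: two unit-mass layers grouped into a block, then nested inside a larger network.
layer1 = Atom(norm=l2, mass=1.0)
layer2 = Atom(norm=l2, mass=1.0)
block = Compound([layer1, layer2])                 # mass 2.0
net = Compound([block, Atom(norm=l2, mass=2.0)])   # mass 4.0
print(net.modular_norm([[[3.0, 4.0], [0.0, 1.0]], [1.0, 0.0]]))
```

Because the rule asks each child only for its own norm and mass, nesting one Compound inside another needs no extra code: the norm of any architecture built this way follows from the tree. Changing the masses reweights how much each child's weights contribute to the total.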
Abstract
To improve performance in contemporary deep learning, one is interested in scaling up the neural network in terms of both the number and the size of the layers. When ramping up the width of a single layer, graceful scaling of training has been linked to the need to normalize the weights and their updates in the "natural norm" particular to that layer. In this paper, we significantly generalize this idea by defining the modular norm, which is the natural norm on the full weight space of any neural network architecture. The modular norm is defined recursively in tandem with the network architecture itself. We show that the modular norm has several promising applications. On the practical side, the modular norm can be used to normalize the updates of any base optimizer so that the learning rate becomes transferable across width and depth. This means that the user does not need to compute optimizer-specific scale factors in order to scale training. On the theoretical side, we show that for any neural network built from "well-behaved" atomic modules, the gradient of the network is Lipschitz-continuous in the modular norm, with the Lipschitz constant admitting a simple recursive formula. This characterization opens the door to porting standard ideas in optimization theory over to deep learning. We have created a Python package called Modula that automatically normalizes weight updates in the modular norm of the architecture. The package is available via "pip install modula" with source code at https://github.com/jxbz/modula.
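The single-layer idea the abstract starts from, normalizing an update in a layer's natural norm, can be sketched in plain Python. The sketch assumes that the natural norm of a d_out × d_in linear layer is the RMS→RMS operator norm, sqrt(d_in / d_out) times the spectral norm, and it estimates the spectral norm by power iteration. The function names are hypothetical, and this is not Modula's implementation. Rescaling a base optimizer's proposed update so that its natural norm equals the learning rate lr makes the step size, measured in that norm, independent of the layer's shape.

```python
import math
import random


def matvec(W, v):
    """Compute W v for a matrix stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def rmatvec(W, u):
    """Compute W^T u for a matrix stored as a list of rows."""
    return [sum(W[i][j] * u[i] for i in range(len(W))) for j in range(len(W[0]))]


def l2(x):
    return math.sqrt(sum(t * t for t in x))


def spectral_norm(W, iters=200, seed=0):
    """Estimate the largest singular value of W by power iteration on W^T W."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(W[0]))]
    for _ in range(iters):
        n = l2(v)
        if n == 0.0:
            # The iterate was annihilated; from a random start this
            # (almost surely) only happens when W is zero.
            return 0.0
        v = [t / n for t in v]
        v = rmatvec(W, matvec(W, v))
    n = l2(v)
    if n == 0.0:
        return 0.0
    v = [t / n for t in v]
    return l2(matvec(W, v))


def rms_operator_norm(W):
    """Assumed natural norm of a d_out x d_in linear layer: the RMS -> RMS operator norm."""
    d_out, d_in = len(W), len(W[0])
    return math.sqrt(d_in / d_out) * spectral_norm(W)


def normalized_update(update, lr):
    """Rescale a base optimizer's proposed update so its natural norm equals lr."""
    norm = rms_operator_norm(update)
    if norm == 0.0:
        return [[0.0 for _ in row] for row in update]
    return [[lr * g / norm for g in row] for row in update]


# Usage: proposed updates for two layers of different shape.
narrow = [[3.0, 0.0], [0.0, 1.0]]
wide = [[3.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
for G in (narrow, wide):
    print(round(rms_operator_norm(normalized_update(G, 0.1)), 6))
```

For both shapes, the normalized step has natural norm lr = 0.1, so the same learning rate yields the same step size in this norm. The modular norm generalizes this per-layer normalization to the whole network.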
