Gradient Routing: Masking Gradients to Localize Computation in Neural Networks
Alex Cloud, Jacob Goldman-Wetzler, Evžen Wybitul, Joseph Miller, Alexander Matt Turner
TL;DR
Gradient routing modifies backpropagation with data-dependent gradient masks that localize learning to predefined network regions. This enables mechanistic supervision, robust unlearning, and scalable oversight without changing the training loss. The method is demonstrated on MNIST and language-model tasks, on unlearning with Expand, Route, Ablate (ERA), and on scalable oversight in reinforcement learning. These experiments highlight an absorption effect: a region trained on a limited subset of data also comes to handle related data that was never routed to it. The work shows that localization remains robust when labels are limited and that it scales to larger models, offering a safety-focused tool for controlling the internal mechanisms of complex AI systems.
Abstract
Neural networks are trained primarily based on their inputs and outputs, without regard for their internal mechanisms. These neglected mechanisms determine properties that are critical for safety, like (i) transparency; (ii) the absence of sensitive information or harmful capabilities; and (iii) reliable generalization of goals beyond the training distribution. To address this shortcoming, we introduce gradient routing, a training method that isolates capabilities to specific subregions of a neural network. Gradient routing applies data-dependent, weighted masks to gradients during backpropagation. These masks are supplied by the user in order to configure which parameters are updated by which data points. We show that gradient routing can be used to (1) learn representations which are partitioned in an interpretable way; (2) enable robust unlearning via ablation of a pre-specified network subregion; and (3) achieve scalable oversight of a reinforcement learner by localizing modules responsible for different behaviors. Throughout, we find that gradient routing localizes capabilities even when applied to a limited, ad-hoc subset of the data. We conclude that the approach holds promise for challenging, real-world applications where quality data are scarce.
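The routing rule described above can be illustrated with a minimal, dependency-free sketch. It assumes a hypothetical toy model whose output is the sum of two linear modules, `u` and `v`, trained on two synthetic tasks. The tasks, the mask values, and helper names such as `routed` and `ablate` are illustrative and not taken from the paper. For simplicity, the masks here scale each module's parameter gradients per example, which is a special case of masking gradients during backpropagation. Zeroing `u` afterwards stands in for unlearning by ablating a pre-specified subregion.

```python
import random

# Hypothetical toy setup (not from the paper): two tasks whose inputs use
# disjoint features. Task "A": y = 2 * x0. Task "B": y = -3 * x1.
TS = [-1.0, -0.5, 0.5, 1.0]
TASK_A = [((t, 0.0), 2.0 * t, "A") for t in TS]
TASK_B = [((0.0, t), -3.0 * t, "B") for t in TS]
DATA = TASK_A + TASK_B


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def predict(params, x):
    # The model is a sum of two modules, u and v: f(x) = u . x + v . x
    return dot(params["u"], x) + dot(params["v"], x)


def routed(label):
    # User-supplied, data-dependent mask: "A" examples may update only u,
    # "B" examples only v. Weights between 0 and 1 would scale, not block.
    return {"u": 1.0, "v": 0.0} if label == "A" else {"u": 0.0, "v": 1.0}


def unrouted(label):
    # All-ones mask: ordinary backpropagation.
    return {"u": 1.0, "v": 1.0}


def train(data, route, steps=2000, lr=0.05, seed=0):
    """SGD on 0.5 * (f(x) - y)**2, scaling each module's gradient by its mask."""
    rng = random.Random(seed)
    dim = len(data[0][0])
    params = {"u": [0.0] * dim, "v": [0.0] * dim}
    for _ in range(steps):
        x, y, label = rng.choice(data)
        err = predict(params, x) - y  # dL/df for this example
        mask = route(label)
        for name, w in params.items():
            for i in range(dim):
                # df/dw_i = x_i; the mask decides how much of it is applied.
                w[i] -= lr * mask[name] * err * x[i]
    return params


def ablate(params, name):
    # Unlearning by ablation: zero out a pre-specified module.
    out = {k: list(w) for k, w in params.items()}
    out[name] = [0.0] * len(out[name])
    return out


def mse(params, data):
    return sum((predict(params, x) - y) ** 2 for x, y, _ in data) / len(data)


if __name__ == "__main__":
    for tag, route in [("routed", routed), ("unrouted", unrouted)]:
        params = ablate(train(DATA, route), "u")
        print(tag, "after ablating u:",
              "loss A =", round(mse(params, TASK_A), 3),
              "loss B =", round(mse(params, TASK_B), 3))
```

With routing, ablating `u` should erase task A while leaving task B essentially intact. Without routing, both modules receive identical updates and each holds half of both tasks, so ablating `u` damages both. Computing gradients by hand keeps the sketch dependency-free; in an autodiff framework the masks would instead be applied inside the backward pass.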
