Gradient Routing: Masking Gradients to Localize Computation in Neural Networks

Alex Cloud, Jacob Goldman-Wetzler, Evžen Wybitul, Joseph Miller, Alexander Matt Turner

TL;DR

Gradient routing modifies backpropagation with data-dependent gradient masks so that learning from chosen data is localized to predefined network regions, without altering the loss. This enables mechanistic supervision, robust unlearning, and scalable oversight. The method is demonstrated on an MNIST autoencoder, on language models (including unlearning with Expand-Route-Ablate, ERA), and on a reinforcement learner. The experiments highlight an absorption effect: routing a limited, labeled subset of data to a region also localizes related capabilities learned from the remaining data. Localization stays robust when few labels are available and scales to larger models, which makes gradient routing a safety-focused tool for controlling internal mechanisms in complex AI systems.

Abstract

Neural networks are trained primarily based on their inputs and outputs, without regard for their internal mechanisms. These neglected mechanisms determine properties that are critical for safety, like (i) transparency; (ii) the absence of sensitive information or harmful capabilities; and (iii) reliable generalization of goals beyond the training distribution. To address this shortcoming, we introduce gradient routing, a training method that isolates capabilities to specific subregions of a neural network. Gradient routing applies data-dependent, weighted masks to gradients during backpropagation. These masks are supplied by the user in order to configure which parameters are updated by which data points. We show that gradient routing can be used to (1) learn representations which are partitioned in an interpretable way; (2) enable robust unlearning via ablation of a pre-specified network subregion; and (3) achieve scalable oversight of a reinforcement learner by localizing modules responsible for different behaviors. Throughout, we find that gradient routing localizes capabilities even when applied to a limited, ad-hoc subset of the data. We conclude that the approach holds promise for challenging, real-world applications where quality data are scarce.
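
As a rough illustration of the masking idea described in the abstract, the sketch below computes routed gradients by hand for a toy two-unit linear model. The model, the squared-error loss, and the function name routed_gradients are assumptions made for this example and are not taken from the paper. In an autograd framework such as PyTorch, the same effect is commonly obtained with a stop-gradient, for example mask * a + (1 - mask) * a.detach().

```python
# Toy sketch of gradient routing with hand-written gradients (stdlib only).
# Assumed model (illustrative, not from the paper):
#   a_j   = w[j] * x                for each unit j
#   h_j   = m_j * a_j + (1 - m_j) * stopgrad(a_j)
#   y_hat = sum_j h_j
#   loss  = 0.5 * (y_hat - y) ** 2
# The forward value of h_j equals a_j, so the mask never changes the output;
# it only scales the gradient that flows back through unit j.

def routed_gradients(w, x, y, mask):
    """Return (loss, grads) for one data point and its per-unit gradient mask."""
    if len(w) != len(mask):
        raise ValueError("need exactly one mask entry per unit")
    activations = [w_j * x for w_j in w]
    y_hat = sum(activations)          # identical to the unmasked forward pass
    error = y_hat - y
    loss = 0.5 * error ** 2
    # d loss / d h_j = error, d h_j / d a_j = m_j, d a_j / d w_j = x
    grads = [m_j * error * x for m_j in mask]
    return loss, grads


def sgd_step(w, grads, lr=0.1):
    """Plain gradient descent update on the routed gradients."""
    return [w_j - lr * g_j for w_j, g_j in zip(w, grads)]


if __name__ == "__main__":
    w = [1.0, 2.0]
    # This data point is routed to unit 0 only.
    loss, grads = routed_gradients(w, x=2.0, y=0.0, mask=[1.0, 0.0])
    print(loss, grads, sgd_step(w, grads))
```

With the mask [1.0, 0.0], only the first weight is updated by this data point, while the loss is the same as it would be without routing. Fractional mask entries re-weight the gradient instead of blocking it.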

Paper Structure

This paper contains 30 sections, 11 equations, 14 figures, 5 tables.

Figures (14)

  • Figure 1: Gradient routing applies weighted masks to selectively block or re-weight gradients during backpropagation. By supplying different masks for different data, the user can induce specialization in network subregions. The figure shows three masks, which would correspond to three data points.
  • Figure 2: Example of gradient routing implemented in PyTorch. For each batch of training data points x, a batch of gradient masks corresponding to those data points is passed as well. The detach() method applies the stop-gradient operator, preventing gradients from being backpropagated through the detached activation but leaving its value unchanged.
  • Figure 3: Gradient routing induces a clean split in the encodings of a simple MLP autoencoder trained on MNIST digits. By applying data-dependent stop-gradients and L1 regularization, the top half of the encoding comes to represent digits 0–4 only, and the bottom half of the encoding comes to represent digits 5–9 only.
  • Figure 4: Backpropagation in the Route step of Expand-Route-Ablate, showing the flow of gradients through a Transformer for tokens in the forget set. This assumes a learning rate of zero for the original dimensions in target layers. Gradients for retain tokens are unmodified. Additional dimensions, shown with dashed outlines, were added to target layers in the MLP and attention blocks, and will be removed after training in the Ablate step. All modules participate in the forward pass.
  • Figure 5: Effect of unlearning methods on forget and retain validation loss depending on the proportion of forget samples labeled. Highlighted regions denote 95% C.I. for the mean across at least $N=5$ training runs. Left: how much each method increases forget loss after it is applied. For ERA and DEMix + ablate, this is pre- vs. post-ablation. Center: how much forget loss increases after a method is applied and the model is fine-tuned on 64 forget stories. (The minimum validation forget loss over fine-tuning is reported.) Right: the retain set performance after applying each method. Note: we include an additional data point for RMU at 0.95 of forget stories labeled. We also include a point for ERA+RMU (denoted with a "+") at full labeling.
  • ...and 9 more figures
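
The split described in the Figure 3 caption depends on choosing a mask from each image's label. A minimal sketch of such a mask follows. The helper name encoding_mask and the default encoding size of 32 are hypothetical, and the sketch assumes that "top half" means the first half of the encoding vector.

```python
# Hypothetical helper: build a per-example gradient mask over the encoding.
# Assumption: digits 0-4 route gradients to the first half of the encoding,
# digits 5-9 to the second half. The default size of 32 is illustrative only.

def encoding_mask(digit, encoding_dim=32):
    """Return a list of 0.0/1.0 entries, one per encoding dimension."""
    if not 0 <= digit <= 9:
        raise ValueError("digit must be in 0..9")
    if encoding_dim < 2:
        raise ValueError("encoding_dim must be at least 2")
    half = encoding_dim // 2
    rest = encoding_dim - half
    if digit <= 4:
        return [1.0] * half + [0.0] * rest
    return [0.0] * half + [1.0] * rest


if __name__ == "__main__":
    # Small encoding size so the masks are easy to read.
    for label in (3, 7):
        print(label, encoding_mask(label, encoding_dim=4))
```

A mask of this kind would be multiplied into the gradient of the encoding for each example, so that each half of the encoding is updated only by its own group of digits.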