
Easing Optimization Paths: a Circuit Perspective

Ambroise Odonnat, Wassim Bouaziz, Vivien Cabannes

TL;DR

The paper studies how gradient descent navigates the internal computations of deep networks by adopting the circuit perspective from mechanistic interpretability. It illustrates this on a controlled sparse modular addition task over $\mathbb{F}_p$ using a one-layer Transformer with cross-attention, showing that gradient updates reinforce useful circuits while pruning spurious ones. A key finding is that learning proceeds through the emergence of sub-circuits that perform intermediate steps (e.g., computing the sum before applying the final modulo), and that curriculum learning or careful data curation can expose and compose these sub-circuits to achieve generalization rather than memorization. These insights suggest practical strategies for efficient training and for reliability in large models; code is available in the authors' GitHub repository.
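As a concrete illustration of the task, here is a minimal sketch of a data sampler. It assumes tokens are drawn uniformly from $\{0, \dots, p-1\}$ and that the label is the sum of the first $k$ tokens modulo $p$, as the figure captions describe. The function name and the sequence length `seq_len=12` are hypothetical choices, not taken from the authors' code.

```python
import numpy as np

def sample_sparse_modular_addition(n, seq_len, k, p, rng=None):
    """Sample n sequences of length seq_len with tokens in {0, ..., p-1}.

    Assumption (illustrative): the label depends only on the first k tokens,
    y = (x_1 + ... + x_k) mod p; the remaining seq_len - k tokens are distractors.
    """
    if not 0 < k <= seq_len:
        raise ValueError("need 0 < k <= seq_len")
    rng = np.random.default_rng() if rng is None else rng
    x = rng.integers(0, p, size=(n, seq_len))  # uniform tokens in [0, p)
    y = x[:, :k].sum(axis=1) % p               # sparse target: first k tokens only
    return x, y

# Hypothetical usage with k=5 and p=2, the values quoted in the figure captions.
x, y = sample_sparse_modular_addition(n=4, seq_len=12, k=5, p=2, rng=np.random.default_rng(0))
```

Only the first $k$ positions carry signal, so a model that generalizes must learn to attend to them and ignore the rest, which is the behavior the attention-weight figures track.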

Abstract

Gradient descent is the method of choice for training large artificial intelligence systems. As these systems become larger, a better understanding of the mechanisms behind gradient training would allow us to alleviate compute costs and help steer these systems away from harmful behaviors. To that end, we suggest utilizing the circuit perspective brought forward by mechanistic interpretability. After laying out our intuition, we illustrate how it enables us to design a curriculum for efficient learning in a controlled setting. The code is available at https://github.com/facebookresearch/pal.
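The curriculum mentioned in the abstract can be sketched as a two-phase schedule: pretrain on $p=2$, then switch to $p=4$, matching the $p=2 \rightarrow 4$ setting in Figure 5. This is a hypothetical illustration, not the code from the repository; `switch_step`, the batch size, the sequence length, and uniform token sampling are all assumptions.

```python
import numpy as np

def modulus_schedule(step, switch_step, p_pretrain=2, p_finetune=4):
    """Return the modulus p used at a given step of a two-phase curriculum."""
    return p_pretrain if step < switch_step else p_finetune

def curriculum_batches(num_steps, switch_step, batch_size, seq_len, k, seed=0):
    """Yield (step, p, x, y): pretrain on p=2, then finetune on p=4 (illustrative)."""
    rng = np.random.default_rng(seed)
    for step in range(num_steps):
        p = modulus_schedule(step, switch_step)
        x = rng.integers(0, p, size=(batch_size, seq_len))  # uniform tokens in [0, p)
        y = x[:, :k].sum(axis=1) % p  # label depends only on the first k tokens
        yield step, p, x, y
```

A training loop would consume these batches, e.g. `for step, p, x, y in curriculum_batches(...)`, and update the model on each (model and optimizer omitted). Because the label space grows from 2 to 4 classes, such a model would need an output layer sized for the largest modulus from the start; this too is an assumption of the sketch.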

Paper Structure

This paper contains 11 sections, 2 equations, and 7 figures.

Figures (7)

  • Figure 1: Analogy between neural networks and electrical circuits, with different components routing the flow of electricity or information. A center-tapped full-wave rectifier can be implemented as a 2-layer neural network with ReLU activations. Red and blue arrows represent +1 and -1 weights, respectively. Light blue represents the bias term.
  • Figure 2: Visualization of $(s_t) = \operatorname{softmax}((x_t + p_t)^\top W_Q / \sqrt{d})$, the attention scores for a fixed sequence made up of $(x_t) = 3$, displayed at different points during training. The strength of the attention is shown by the thickness of the arrows, while the color indicates the sign of the most recent update: red for arrows that have just been thickened, blue for those that have just been thinned.
  • Figure 3: Evolution of the attention weights under gradient descent. Each row corresponds to a training iteration and each column to an entry $x_t$ of the input sequence $x$; the darker the cell, the higher the attention weight. Ultimately, the Transformer learns to focus solely on the first $k=5$ input tokens, which are the ones defining the output $y$, indicated by the red vertical line. More precisely, it attends to the $0$s among these tokens, counts them, and deduces the number of $1$s to make its final prediction.
  • Figure 4: Representation in the plane of the $d=2$-dimensional embeddings $z_A$ obtained after the attention module (see the introduction). Colors represent the sum of the first $k$ tokens. Left: after pretraining with $p=2$, equivalence classes modulo $6$ emerge. Right: after finetuning with $p=4$, equivalence classes modulo $16$ appear.
  • Figure 5: Evolution of the train and test accuracy over training iterations. $p = 4$ denotes the model trained from scratch with $p=4$, and $p=2 \rightarrow 4$ the model first pretrained with $p=2$ and then finetuned with $p=4$. The red dashed line marks the iteration at which training switches from pretraining to finetuning.
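To make two of the captions concrete, the sketch below (i) builds the Figure 1 rectifier as a 2-layer ReLU network computing $|v| = \mathrm{ReLU}(v) + \mathrm{ReLU}(-v)$, with the bias set to zero, and (ii) computes the Figure 2 attention scores $s_t$, treating $W_Q$ as a single learned query vector in $\mathbb{R}^d$ and $x_t$, $p_t$ as token and positional embeddings. Both readings are assumptions for illustration, and the helper names are hypothetical.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def full_wave_rectifier(v):
    """2-layer ReLU network computing |v| (illustrative reading of Figure 1).

    Hidden layer: weights +1 and -1, no bias (assumption), ReLU activation.
    Output layer: both hidden units summed with weight +1.
    """
    w_hidden = np.array([1.0, -1.0])
    w_out = np.array([1.0, 1.0])
    hidden = relu(np.multiply.outer(v, w_hidden))  # shape (..., 2)
    return hidden @ w_out

def attention_scores(x_emb, pos_emb, w_q):
    """s_t = softmax_t((x_t + p_t)^T W_Q / sqrt(d)), with W_Q a single query vector.

    x_emb, pos_emb: arrays of shape (T, d); w_q: array of shape (d,).
    """
    d = x_emb.shape[-1]
    logits = (x_emb + pos_emb) @ w_q / np.sqrt(d)
    logits -= logits.max()  # numerical stability; softmax is invariant to shifting all logits
    weights = np.exp(logits)
    return weights / weights.sum()
```

The rectifier shows how fixed +1/-1 weights route positive and negative inputs through separate units before recombining them; the attention function returns a probability vector over the $T$ positions, whose entries are what Figures 2 and 3 visualize.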