Mechanistic Interpretability of Binary and Ternary Transformers

Jason Li

TL;DR

This paper investigates whether binarized and ternarized transformer networks offer interpretability advantages over full-precision models. It applies mechanistic interpretability to the discrete toy problem of modular addition: it reverse-engineers the learned algorithms, compares their Fourier/clock-like representations with those of full-precision networks, and examines grokking dynamics. It is the first study to apply mechanistic interpretability to binary/ternary transformers. It finds that these discretized models tend to learn algorithms similar to those of full-precision models (with some added noise) rather than simpler, more interpretable strategies. The findings suggest that discretization alone does not yield more interpretable algorithms in this setting. This motivates future work on other tasks, on optimization techniques, and on fully binarized/ternarized architectures to better assess any interpretability benefits.
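The modular-addition task and the clock-like representation mentioned above can be illustrated with a small numpy sketch. This is not the paper's code. The modulus P = 113, the input layout [a, b, '='], the token id chosen for '=' and the example frequencies are all assumptions; the '=' token is suggested only by the caption of Figure 3. The idealized clock algorithm never computes a + b directly. Instead, it combines per-token cos/sin features through trigonometric identities, so that the logit for a candidate answer c is largest exactly at c = (a + b) mod p.

```python
import numpy as np

# Hypothetical modulus: the summary does not state which prime the paper uses.
P = 113

def make_dataset(p: int = P) -> tuple[np.ndarray, np.ndarray]:
    """All p*p inputs [a, b, '='] with label (a + b) mod p.

    Assumption: the '=' token is encoded as id p, one past the last residue.
    """
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    a, b = a.ravel(), b.ravel()
    eq = np.full_like(a, p)
    inputs = np.stack([a, b, eq], axis=1)
    labels = (a + b) % p
    return inputs, labels

def clock_logits(a: int, b: int, freqs: list[int], p: int = P) -> np.ndarray:
    """Idealized clock logits: sum over key frequencies k of cos(w_k (a + b - c)).

    Only products of per-token cos/sin features are used:
      cos(w(a+b)) = cos(wa)cos(wb) - sin(wa)sin(wb)
      sin(w(a+b)) = sin(wa)cos(wb) + cos(wa)sin(wb)
    """
    c = np.arange(p)
    logits = np.zeros(p)
    for k in freqs:
        w = 2 * np.pi * k / p
        cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)
        sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)
        # cos(w(a+b-c)) = cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc)
        logits += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)
    return logits
```

For a prime p and nonzero frequencies, every term equals 1 only when c = (a + b) mod p. As a result, `np.argmax(clock_logits(5, 110, [14, 35, 41]))` returns 2, which is (5 + 110) mod 113. The frequencies here are arbitrary stand-ins for the key frequencies a trained model would pick.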

Abstract

Recent research (arXiv:2310.11453, arXiv:2402.17764) has proposed binary and ternary transformer networks as a way to significantly reduce memory use and improve inference speed in Large Language Models (LLMs) while maintaining accuracy. In this work, we apply techniques from mechanistic interpretability to investigate whether such networks learn algorithms that differ from, or resemble, those learned by full-precision transformer networks. In particular, we reverse-engineer the algorithms learned for the toy problem of modular addition and find that binary and ternary networks learn algorithms similar to those of full-precision networks. This provides evidence against using binary and ternary networks as a more interpretable alternative in the LLM setting.
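This summary does not state exactly how the paper discretizes weights, or which layers it discretizes. The sketch below assumes the two weight quantizers described in the cited works. The first is sign binarization with mean centering in the style of BitNet (arXiv:2310.11453). The second is absmean ternarization in the style of BitNet b1.58 (arXiv:2402.17764). The function names are hypothetical. The sketch covers only the forward-pass quantization; training such layers typically also needs a straight-through estimator for gradients, which is omitted here.

```python
import numpy as np

def binarize(w: np.ndarray) -> tuple[np.ndarray, float]:
    """Sign binarization in the style of BitNet (an assumed scheme).

    Weights are centred by their mean and mapped to {-1, +1}. The returned
    per-tensor scale beta = mean(|W|) is used to rescale the layer's output.
    """
    alpha = w.mean()
    w_b = np.where(w - alpha > 0, 1.0, -1.0)
    beta = float(np.abs(w).mean())
    return w_b, beta

def ternarize(w: np.ndarray, eps: float = 1e-5) -> tuple[np.ndarray, float]:
    """Absmean ternarization in the style of BitNet b1.58 (an assumed scheme).

    W is divided by gamma = mean(|W|), rounded to the nearest integer and
    clipped to {-1, 0, +1}. Note that np.round rounds exact halves to even.
    """
    gamma = float(np.abs(w).mean())
    w_t = np.clip(np.round(w / (gamma + eps)), -1, 1)
    return w_t, gamma
```

For example, `ternarize(np.array([[0.1, -2.0], [0.5, 1.5]]))` gives gamma ≈ 1.025 and the ternary matrix [[0, -1], [0, 1]]. Small weights collapse to 0, which is the extra state that ternary networks have and binary ones lack.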

Paper Structure

This paper contains 9 sections, 4 equations, and 6 figures.

Figures (6)

  • Figure 1: Training and test loss over 10,000 epochs. The training loss converges to around 0.006 and the test loss to around 0.01.
  • Figure 2: The norms of the Fourier components of the embedding matrix. A few key frequencies stand out.
  • Figure 3: (Top-left) The attention score for head 0 from token '=' to token $a$ as a function of inputs $a,b$. (Top-right) The activations of MLP neuron 0. (Bottom) The norm of the Fourier components of logits. All of these correspond to the binary model.
  • Figure 4: For a ternary model: (Left) Loss curve. Both train and test loss converge to around 0.03. (Right) The norms of the Fourier components of the embedding matrix.
  • Figure 5: The fraction of variance explained by degree-2 polynomials of a single frequency.
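Figures 2 and 5 rest on two measurements. The first is how strongly each Fourier frequency appears in the embedding matrix. The second is how much of a neuron's activation, viewed as a function of (a, b), a single-frequency polynomial can explain. The summary does not include the paper's code, so the following is a minimal numpy sketch under stated assumptions. It assumes that p is prime, that the embedding rows are ordered by residue with the '=' row removed, and that "degree-2 polynomial of a single frequency" means all monomials of degree at most two in cos(wa), sin(wa), cos(wb) and sin(wb). The function names are hypothetical. The norms come from a DFT, so they match an explicit cos/sin basis only up to a constant scale.

```python
import numpy as np

def embedding_fourier_norms(w_e: np.ndarray) -> np.ndarray:
    """Norm of each Fourier frequency of an embedding matrix.

    w_e: (p, d_model) embedding rows for the p residue tokens.
    Returns an array of length p // 2 + 1; entry k is the norm, across all
    d_model columns, of the frequency-k DFT coefficients over the token axis.
    """
    coeffs = np.fft.rfft(w_e, axis=0)  # shape (p // 2 + 1, d_model), complex
    return np.linalg.norm(coeffs, axis=1)

def single_freq_r2(acts: np.ndarray, k: int) -> float:
    """Fraction of variance of acts[a, b] explained by a degree-2 polynomial
    in cos/sin of w*a and w*b for the single frequency w = 2*pi*k/p.

    acts: (p, p) activations of one neuron over all input pairs (a, b).
    """
    p = acts.shape[0]
    w = 2 * np.pi * k / p
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    base = [np.cos(w * a), np.sin(w * a), np.cos(w * b), np.sin(w * b)]
    feats = [np.ones_like(base[0])] + base
    # Add every degree-2 monomial (squares and cross terms) of the base features.
    for i in range(4):
        for j in range(i, 4):
            feats.append(base[i] * base[j])
    x = np.stack([f.ravel() for f in feats], axis=1)
    y = acts.ravel()
    # Least squares copes with the collinearity cos^2 + sin^2 = 1.
    coef, *_ = np.linalg.lstsq(x, y, rcond=None)
    resid = y - x @ coef
    return 1.0 - resid.var() / y.var()
```

A per-neuron summary in the spirit of Figure 5 would take the maximum of `single_freq_r2` over k = 1, ..., p // 2. Whether the paper aggregates the results this way is not stated here.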