Counting and Algorithmic Generalization with Transformers

Simon Ouellette, Rolf Pfister, Hansueli Jud

TL;DR

This work examines algorithmic generalization through counting tasks and shows that standard Transformers fail to generalize out-of-distribution because of two architectural choices: softmax normalization of the attention weights and layer normalization. By ablating these components, the authors demonstrate a lightweight Transformer that learns to count in a parallel fashion and generalizes to much larger grids. The study also finds that layer normalization tends to tether models to the training distribution, hindering generalization in both counting and identity tasks. Together, the results suggest that normalization and attention mechanisms should be reassessed when Transformers are applied to non-NLP, algorithmic tasks, and they offer a concrete path toward better generalization with minimal architectural changes.

Abstract

Algorithmic generalization in machine learning refers to the ability to learn the underlying algorithm that generates data in a way that generalizes out-of-distribution. This is generally considered a difficult task for most machine learning algorithms. Here, we analyze algorithmic generalization when counting is required, either implicitly or explicitly. We show that standard Transformers are based on architectural decisions that hinder out-of-distribution performance for such tasks. In particular, we discuss the consequences of using layer normalization and of normalizing the attention weights via softmax. By ablating the problematic operations, we demonstrate that a modified Transformer can exhibit good algorithmic generalization performance on counting while using a very lightweight architecture.
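
The two operations named in the abstract can be illustrated with a toy sketch. This is not the authors' architecture: the helper names (`attend`, `layer_norm`, `count_matches`) and the setup, in which every grid cell matches the query and carries the value 1.0, are assumptions made purely for illustration. The point is that softmax weights sum to 1, so the attention output is a weighted average that does not change with the number of matching cells, whereas an unnormalized sum grows with the count. Similarly, layer normalization (shown here without learned gain and bias) maps an input and a rescaled copy of it to nearly the same output, so information carried by magnitude is largely discarded.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(scores, values, normalize):
    """Single-query attention over scalar values.

    normalize=True  -> softmax-weighted average (standard attention)
    normalize=False -> raw scores used as weights (softmax ablated)
    """
    weights = softmax(scores) if normalize else scores
    return sum(w * v for w, v in zip(weights, values))


def layer_norm(x, eps=1e-5):
    """Layer normalization without learned gain and bias."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def count_matches(n_cells, normalize):
    """Toy case (assumption): all n_cells match the query, each with value 1.0."""
    scores = [1.0] * n_cells
    values = [1.0] * n_cells
    return attend(scores, values, normalize)


# Grid sizes borrowed from the figure captions: 6x6 and 15x15.
small_softmax = count_matches(6 * 6, normalize=True)
large_softmax = count_matches(15 * 15, normalize=True)
small_sum = count_matches(6 * 6, normalize=False)
large_sum = count_matches(15 * 15, normalize=False)

# Layer norm on a vector and on the same vector scaled by 10.
ln_small = layer_norm([1.0, 2.0, 3.0])
ln_large = layer_norm([10.0, 20.0, 30.0])

if __name__ == "__main__":
    print("softmax attention:", small_softmax, large_softmax)
    print("sum attention:    ", small_sum, large_sum)
    print("layer norm:       ", ln_small, ln_large)
```

In this toy setting the softmax variant returns essentially the same value for both grid sizes, while the unnormalized variant returns the number of matching cells, which is one way to see why removing the normalization could help a model extrapolate counts to larger grids.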

Paper Structure

This paper contains 24 sections, 5 equations, 9 figures, 3 tables, and 1 algorithm.

Figures (9)

  • Figure 1: Generalization performance (%) comparison of different models
  • Figure 2: Histograms of count predictions (left) vs ground truths (right) for 6x6 grids (LayerNorm-FF-Count)
  • Figure 3: Histograms of count predictions (left) vs ground truths (right) for 15x15 grids (LayerNorm-FF-Count)
  • Figure 4: Histograms of std. deviations for successful predictions (left) vs failed predictions (right) count values on 12x12 grids (LayerNorm-FF-Count)
  • Figure 5: Histograms of count predictions (left) vs ground truths (right) for 6x6 grids (LayerNorm-SA-Count)
  • ...and 4 more figures