
It Ain't That Bad: Understanding the Mysterious Performance Drop in OOD Generalization for Generative Transformer Models

Xingcheng Xu, Zihao Pan, Haipeng Zhang, Yanqing Yang

TL;DR

This work investigates why generative Transformer models exhibit strong in-distribution (ID) generalization on $n$-digit arithmetic yet poor out-of-distribution (OOD) generalization for longer inputs. By training small decoder-only Transformers on $n$-digit addition and multiplication, the authors reveal that ID generalization relies on structured representations, while the OOD gap reflects an equivalence-generalization mechanism based on learned residue-class relations modulo $p=10^3$. They characterize this behavior through algebraic (residue class) structures, probability distributions (greedy decoding), and representation dynamics (PCA of token embeddings), and show these patterns persist across model sizes and encoding schemes. The findings provide mechanistic insights into generalization and suggest priors for improving cross-length robustness and domain adaptation in large language models.
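To make the residue-class idea concrete, here is a minimal Python sketch for addition. It assumes one reading of the summary above: that a model trained on 3-digit operands treats each longer operand as its residue modulo $10^3$ and answers as if it had seen the in-distribution pair. The function names are hypothetical and the rule is an illustration, not the authors' code; the operands are taken from the Figure 3 caption.

```python
P = 10**3  # modulus named in the TL;DR


def residue_pair(a: int, b: int, p: int = P) -> tuple[int, int]:
    """Map an (OOD) operand pair to its in-distribution representative."""
    return a % p, b % p


def equivalence_prediction(a: int, b: int, p: int = P) -> int:
    """Sum of the residues: the answer for the ID pair in the same class.

    Assumption: this is the output an 'equivalence-generalizing' model
    would give for addition; it is not the true sum when a or b >= p.
    """
    ra, rb = residue_pair(a, b, p)
    return ra + rb


if __name__ == "__main__":
    for a, b in [(349, 705), (1349, 705), (349, 2705), (1349, 2705)]:
        print(f"{a}+{b}: true sum {a + b}, equivalence class answer "
              f"{equivalence_prediction(a, b)}")
```

Under this assumption, all four operand pairs fall in the same class as $349+705$, so the sketch returns the same value for each even though the true sums differ.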

Abstract

Large language models (LLMs) have achieved remarkable proficiency in solving diverse problems. However, their generalization ability is not always satisfactory, and the generalization problem is common to generative transformer models in general. Researchers use basic mathematical tasks such as n-digit addition or multiplication as important lenses for investigating generalization behavior. It is observed that when models are trained on n-digit operations (e.g., additions) in which both input operands are n digits long, they generalize successfully to unseen n-digit inputs (in-distribution (ID) generalization) but fail miserably on longer, unseen cases (out-of-distribution (OOD) generalization). We draw attention to this unexplained performance drop and ask whether there is systematic OOD generalization. Towards understanding LLMs, we train various smaller language models that may share the same underlying mechanism. We discover that the strong ID generalization stems from structured representations, while behind the unsatisfactory OOD performance the models still exhibit clear learned algebraic structures. Specifically, these models map unseen OOD inputs to outputs using equivalence relations learned in the ID domain, which we call equivalence generalization. These findings deepen our knowledge of the generalizability of generative models, including LLMs, and provide insights into potential avenues for improvement.

Paper Structure (18 sections, 9 equations, 5 figures, 3 tables)

Figures (5)

  • Figure 1: Training curves in addition and multiplication operations.
  • Figure 2: Contour plots for addition and multiplication operations.
  • Figure 3: The probability distribution of each digit of the sequence in an addition operation $c=a+b$. The left side of the black dashed line is the input $a+b$, and the right side is the result $c$. Panels (a) and (e) show $349+705$ and $128+256$, whose outputs are $1,054$ and $384$ ($450100$ and $483000$ in the actual sequence output), respectively. In the second column, the thousands digit of $a$ is perturbed: panel (b) shows $1,349+705$ and panel (f) shows $3,128+256$. In the third column, the thousands digit of $b$ is perturbed: panel (c) shows $349+2,705$ and panel (g) shows $128+4,256$. In the fourth column, the thousands digits of both $a$ and $b$ are perturbed: panel (d) shows $1,349+2,705$ and panel (h) shows $3,128+4,256$.
  • Figure 4: 3D representation structure of the first three principal components in the addition operation. Panels (a) to (d) show the initial model and the models at 14%, 51%, and 100% test accuracy, respectively.
  • Figure 5: Accuracy of the OOD equivalence test for different model and data scales.
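
The Figure 3 caption gives two results together with their "actual sequence output": $1,054$ as `450100` and $384$ as `483000`. Both pairs are consistent with writing the digits least-significant first and right-padding with zeros to six characters. The sketch below encodes that reading; the reversed-digit interpretation and the fixed width of 6 are inferred from these two examples only, and the function names are hypothetical.

```python
WIDTH = 6  # assumed output width, inferred from the two caption examples


def encode_result(value: int, width: int = WIDTH) -> str:
    """Write digits least-significant first, then right-pad with zeros."""
    reversed_digits = str(value)[::-1]
    if len(reversed_digits) > width:
        raise ValueError(f"{value} does not fit in {width} digits")
    return reversed_digits.ljust(width, "0")


def decode_result(sequence: str) -> int:
    """Invert encode_result: reverse the digit string and parse it."""
    return int(sequence[::-1])


if __name__ == "__main__":
    for value in (1054, 384):
        seq = encode_result(value)
        print(value, "->", seq, "->", decode_result(seq))
```

Padding on the right of a reversed string adds leading zeros once the string is reversed back, so decoding recovers the original integer.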