Learning Syntax Without Planting Trees: Understanding Hierarchical Generalization in Transformers

Kabir Ahuja, Vidhisha Balachandran, Madhur Panwar, Tianxing He, Noah A. Smith, Navin Goyal, Yulia Tsvetkov

TL;DR

This paper interrogates why transformers generalize hierarchically without explicit structural biases, focusing on the role of training objectives, dataset ambiguity, and representational subnetworks. It shows that the language modeling objective uniquely promotes hierarchical generalization across diverse synthetic tasks, while other objectives often fail to do so. By pruning attention heads, the authors reveal coexisting subnetworks corresponding to hierarchical and linear generalizations, with the linear subnetwork emerging earlier and the hierarchical one later in training, and demonstrate that disambiguating data can suppress the linear pathway. A Bayesian PCFG-based analysis links hierarchical generalization to simpler grammars that still fit the data, offering a principled explanation for why language modeling yields a hierarchical bias. Collectively, the work illuminates how inductive biases in LM training, together with data structure, shape the emergence of hierarchical syntax in transformers, with implications for designing training regimens and interpreting model behavior.
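The "simpler, data-fit grammar" idea corresponds to standard Bayesian model comparison. The following is a minimal statement of that criterion, not the paper's exact formulation. It assumes that sentences in a dataset $\mathcal{D}$ are drawn independently from a grammar $G$, that the candidate set $\mathcal{G}$ contains a hierarchical (context-free) grammar and regular grammars, and that the prior $p(G)$ assigns lower probability to larger grammars. The specific priors and candidate grammars used in the paper are not given here.

$$
\hat{G} \;=\; \arg\max_{G \in \mathcal{G}} \log p(G \mid \mathcal{D})
\;=\; \arg\max_{G \in \mathcal{G}} \Big[ \underbrace{\sum_{s \in \mathcal{D}} \log p(s \mid G)}_{\text{fit to the data}} \;+\; \underbrace{\log p(G)}_{\text{simplicity prior}} \Big]
$$

The evidence term $\log p(\mathcal{D})$ is the same for every candidate, so it drops out of the $\arg\max$. Under this reading, a dataset "favors" hierarchical generalization when the hierarchical grammar attains the highest combined score, even if a regular grammar fits the observed sentences equally well.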

Abstract

Transformers trained on natural language data have been shown to learn its hierarchical structure and generalize to sentences with unseen syntactic structures without explicitly encoding any structural bias. In this work, we investigate sources of inductive bias in transformer models and their training that could cause such generalization behavior to emerge. We extensively experiment with transformer models trained on multiple synthetic datasets and with different training objectives, and show that while other objectives (e.g., sequence-to-sequence modeling and prefix language modeling) often failed to lead to hierarchical generalization, models trained with the language modeling objective consistently learned to generalize hierarchically. We then conduct pruning experiments to study how transformers trained with the language modeling objective encode hierarchical structure. Pruning reveals the joint existence of subnetworks within the model with different generalization behaviors (subnetworks corresponding to hierarchical structure and to linear order). Finally, we take a Bayesian perspective to further uncover transformers' preference for hierarchical generalization: we establish a correlation between whether transformers generalize hierarchically on a dataset and whether the simplest explanation of that dataset is provided by a hierarchical grammar rather than by regular grammars exhibiting linear generalization.
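The three objectives differ mainly in which tokens the model attends to and which tokens it is scored on. The sketch below is a minimal illustration, not the paper's code. It assumes the input (prefix) and output are concatenated into one sequence of length `n_total` whose first `n_prefix` tokens are the input. It also assumes that "language modeling" scores every token, while prefix LM and seq2seq score only the output. The helper names `causal_mask`, `prefix_lm_mask`, and `loss_weights` are hypothetical.

```python
from typing import List

Matrix = List[List[int]]


def causal_mask(n: int) -> Matrix:
    """Row i marks the positions token i may attend to: only j <= i (left to right)."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def prefix_lm_mask(n_prefix: int, n_total: int) -> Matrix:
    """Prefix LM: the first n_prefix tokens attend to each other bidirectionally;
    the remaining (output) tokens are still processed left to right."""
    mask = causal_mask(n_total)
    for i in range(n_prefix):
        for j in range(n_prefix):
            mask[i][j] = 1
    return mask


def loss_weights(objective: str, n_prefix: int, n_total: int) -> List[int]:
    """Return 1 for each position whose token is a prediction target.

    'lm' scores every token of the concatenated sequence (assuming a BOS
    token precedes position 0), so the model must also model the input.
    'prefix_lm' and 'seq2seq' score only the output tokens; in seq2seq the
    input is read by a separate encoder rather than sharing the sequence.
    """
    if objective == "lm":
        return [1] * n_total
    if objective in ("prefix_lm", "seq2seq"):
        return [0] * n_prefix + [1] * (n_total - n_prefix)
    raise ValueError(f"unknown objective: {objective}")
```

For a 2-token input followed by a 2-token output, `loss_weights("lm", 2, 4)` scores all four tokens, whereas `loss_weights("prefix_lm", 2, 4)` scores only the last two. This makes the one structural difference under this assumption visible: only the LM objective is trained to predict the input sentences themselves.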

Paper Structure

This paper contains 63 sections, 4 equations, 20 figures, and 3 tables.

Figures (20)

  • Figure 1: Effect of training objective on hierarchical generalization in transformers. The error bars correspond to the standard errors across 5 random seeds. Only the language modeling objective consistently obtains high generalization accuracy on all tasks.
  • Figure 2: Training an RNN (GRU) using language modeling and seq2seq objectives on the question formation task. 300k training steps correspond to 24 epochs (or passes through the training data).
  • Figure 3: Pruning a transformer LM trained for 15,000 steps with each of the three methods. Dark blocks indicate that the head is pruned; light blocks indicate that it is kept.
  • Figure 4: Tracking training dynamics of the full network and of the subnetworks found by the three pruning methods. (a) and (b): in-distribution and generalization accuracies of the LMs trained on the original ambiguous question formation data after pruning with the three methods; (c): generalization accuracy after pruning the model trained on disambiguated data. For models trained on the original data, we can discover subnetworks consistent with the hierarchical rule as well as the linear rule, while for models trained on disambiguated data, no linear-rule subnetwork is found (indicated by the curve for $\texttt{Train} \backslash \texttt{Gen}$-prune never approaching 0% generalization accuracy).
  • Figure 5: Performance of transformer models trained on the $\mathcal{D}_{\mathrm{train-L}}$ and $\mathcal{D}_{\mathrm{train-S}}$ datasets.
  • ...and 15 more figures
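Head pruning, as in Figures 3 and 4, can be pictured as a 0/1 mask that switches individual attention heads off. The captions do not spell out the pruning objectives. The sketch below is therefore a hypothetical stand-in. It scores candidate masks by brute force using three assumed criteria: keep training accuracy high, keep generalization accuracy high, and keep training accuracy high while driving generalization accuracy down (what the $\texttt{Train} \backslash \texttt{Gen}$-prune name suggests). Practical pruning methods typically learn the mask by gradient descent with a sparsity penalty, which is omitted here. All names (`apply_head_mask`, `all_masks`, `CRITERIA`, `best_mask`) are illustrative.

```python
from itertools import product
from typing import Callable, Dict, List, Sequence, Tuple

Mask = Tuple[int, ...]


def apply_head_mask(head_outputs: Sequence[Sequence[float]], mask: Mask) -> List[float]:
    """Scale each head's output by its 0/1 mask entry, then concatenate.

    In a transformer, this concatenated vector would feed the attention
    output projection, so a 0 entry removes that head's contribution.
    """
    if len(head_outputs) != len(mask):
        raise ValueError("need exactly one mask entry per head")
    out: List[float] = []
    for head, keep in zip(head_outputs, mask):
        out.extend(keep * x for x in head)
    return out


def all_masks(n_heads: int) -> List[Mask]:
    """Enumerate every keep/prune pattern (feasible only for a handful of heads)."""
    return list(product((0, 1), repeat=n_heads))


# Hypothetical scoring rules; each maps (train_acc, gen_acc) to a score.
CRITERIA: Dict[str, Callable[[float, float], float]] = {
    "train_only": lambda train_acc, gen_acc: train_acc,
    "gen_only": lambda train_acc, gen_acc: gen_acc,
    # Mirrors the idea behind the Train\Gen-prune name: fit training data, fail generalization.
    "train_minus_gen": lambda train_acc, gen_acc: train_acc - gen_acc,
}


def best_mask(accuracies: Dict[Mask, Tuple[float, float]], criterion: str) -> Mask:
    """Return the mask with the highest score under the chosen criterion."""
    score = CRITERIA[criterion]
    return max(accuracies, key=lambda m: score(*accuracies[m]))
```

The `train_minus_gen` criterion rewards a mask that still solves the training data but drives generalization accuracy toward 0%. Such a mask is a subnetwork consistent with the linear rule, and Figure 4(c) reports that no such subnetwork is found after training on disambiguated data.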