Sneaking Syntax into Transformer Language Models with Tree Regularization

Ananjan Nandi, Christopher D. Manning, Shikhar Murty

TL;DR

TreeReg introduces a differentiable regularizer that softly imposes hierarchical, constituency-based structure on transformer language models without changing their architecture. By defining a Span Contextual Independence Score (SCIN) that uses orthogonality constraints on prefix and span representations, TreeReg encourages constituent representations to be contextually independent, guiding the model toward tree-like computations. Across pre-training and fine-tuning scenarios, TreeReg improves syntactic generalization, lowers out-of-distribution perplexity, and increases robustness on adversarial NLI. Because it is only an auxiliary training loss, it leaves inference cost unchanged, and parse trees can be recovered from the model's hidden states. The approach shows that architecture-agnostic, regularization-based biases can improve how language models learn linguistic structure, with notable sample-efficiency benefits for both LMs trained from scratch and pre-trained LLMs. The authors release code to facilitate replication and future exploration.
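To make the SCIN idea concrete, here is a minimal, dependency-free Python sketch of a SCIN chart. It is not the authors' implementation. Following the example in Figure 1, we assume that the span vector for a 1-indexed span $(i, j)$ is the prefix state $\bm{h}_j$, that its context is $\bm{h}_{i-1}$ and $\bm{h}_{j+1}$ when they exist, and that the score is the negative mean absolute cosine similarity, so a span whose vector is orthogonal to its context gets the highest score (0). The names `cosine`, `scin`, and `scin_chart` are hypothetical.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]
Span = Tuple[int, int]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return dot / (nu * nv)


def scin(h: List[Vector], i: int, j: int) -> float:
    """Hypothetical SCIN score for the 1-indexed span (i, j).

    Assumption (from Figure 1's example): the span vector is h_j and the
    context vectors are h_{i-1} and h_{j+1}, when they exist. The score is
    the negative mean |cosine|, so orthogonality to the context scores 0,
    the maximum.
    """
    n = len(h)
    span_vec = h[j - 1]            # h_j (1-indexed -> list index j - 1)
    context = []
    if i > 1:
        context.append(h[i - 2])   # h_{i-1}: prefix just before the span
    if j < n:
        context.append(h[j])       # h_{j+1}: prefix just after the span
    if not context:
        return 0.0                 # whole sentence: no context to separate from
    return -sum(abs(cosine(span_vec, c)) for c in context) / len(context)


def scin_chart(h: List[Vector]) -> Dict[Span, float]:
    """SCIN for every span (i, j) with 1 <= i <= j <= n, as in Figure 1 (iii)."""
    n = len(h)
    return {(i, j): scin(h, i, j) for i in range(1, n + 1) for j in range(i, n + 1)}
```

Plain lists keep the sketch self-contained; an actual training loss would compute the same quantities on autodiff tensors so that gradients flow back into the hidden states.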

Abstract

While compositional accounts of human language understanding are based on a hierarchical tree-like process, neural models like transformers lack a direct inductive bias for such tree structures. Introducing syntactic inductive biases could unlock more robust and data-efficient learning in transformer language models (LMs), but existing methods for incorporating such structure greatly restrict models, either limiting their expressivity or increasing inference complexity. This work instead aims to softly inject syntactic inductive biases into given transformer circuits through a structured regularizer. We introduce TreeReg, an auxiliary loss function that converts bracketing decisions from silver parses into a set of differentiable orthogonality constraints on vector hidden states. TreeReg integrates seamlessly with the standard LM objective, requiring no architectural changes. LMs pre-trained with TreeReg on natural language corpora such as WikiText-103 achieve up to 10% lower perplexities on out-of-distribution data and up to 9.5 point improvements in syntactic generalization, requiring less than half the training data to outperform standard LMs. TreeReg still provides gains for pre-trained LLMs: continued pre-training of Sheared Llama with TreeReg results in improved syntactic generalization, and fine-tuning on MultiNLI with TreeReg mitigates degradation of performance on adversarial NLI benchmarks by 41.2 points. We release all code to guide future research.
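The central mechanism, turning bracketing decisions from silver parses into a differentiable loss, can be sketched as a softmax cross-entropy over split points. This is an illustrative assumption rather than the paper's exact loss: for a silver constituent $(i, j)$, each split $k$ gets the bracketing score $\mathrm{SCIN}(i,k) + \mathrm{SCIN}(k+1,j)$ (as in Figure 1 (iv)), and the span loss is the negative log-probability of the silver split. The chart of SCIN values is taken as given, and the silver parse is represented by a hypothetical `gold_splits` dict mapping each constituent span to its split point.

```python
import math
from typing import Dict, Tuple

Span = Tuple[int, int]


def span_loss(chart: Dict[Span, float], i: int, j: int, gold_k: int) -> float:
    """Cross-entropy loss l_(i,j) that pushes the silver split of (i, j) to win.

    Split k divides (i, j) into (i, k) and (k + 1, j); its score is
    SCIN(i, k) + SCIN(k + 1, j). Assumed formulation, not the paper's code.
    """
    scores = {k: chart[(i, k)] + chart[(k + 1, j)] for k in range(i, j)}
    m = max(scores.values())  # subtract the max for a stable log-sum-exp
    log_z = m + math.log(sum(math.exp(s - m) for s in scores.values()))
    return log_z - scores[gold_k]  # = -log softmax(scores)[gold_k]


def treereg_loss(chart: Dict[Span, float], gold_splits: Dict[Span, int]) -> float:
    """Sum of l_(i,j) over silver constituents with more than one possible split."""
    return sum(
        span_loss(chart, i, j, k)
        for (i, j), k in gold_splits.items()
        if j - i >= 2  # spans of 3+ words; 2-word spans have a single split
    )
```

Spans of two words have only one split, so their loss would be zero anyway; skipping them mirrors Figure 1, where $\mathcal{L}_\text{TR}$ sums only $l_{(1,5)}$, $l_{(1,4)}$ and $l_{(2,4)}$. For that sentence, with bracketing ((he (is (very happy))) now), the silver splits are `{(1, 5): 4, (1, 4): 1, (2, 4): 2, (3, 4): 3}`, and an all-zero chart gives each term the log of its number of splits, for a total of log 4 + log 3 + log 2 = log 24.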

Paper Structure

This paper contains 67 sections, 10 equations, 16 figures, 8 tables, and 2 algorithms.

Figures (16)

  • Figure 1: TreeReg loss ($\mathcal{L}_\text{TR}$) computation for $S =$ "he is very happy now". (i) Vector hidden states $\bm{h}_i$ are computed by passing $S$ as input to some circuit of the LM; $\bm{h}_i$ is the representation for the prefix of $S$ ending at position $i$. (ii) Span Contextual Independence Score ($\mathrm{SCIN}$) computation for "is very happy": orthogonality constraints are enforced between the span representation $\bm{h}_4$ and its context, $\bm{h}_1$ and $\bm{h}_5$. (iii) Chart of $\mathrm{SCIN}$ for all spans in $S$. (iv) The possible bracketings of "is very happy" are ("is very", "happy"), with score $\mathrm{SCIN}(2,3) + \mathrm{SCIN}(4,4)$, and ("is", "very happy"), with score $\mathrm{SCIN}(2,2) + \mathrm{SCIN}(3,4)$. The loss for this span, $l_{(2,4)}$, encourages the second bracketing. The total loss $\mathcal{L}_\text{TR} = l_{(1,5)} + l_{(1,4)} + l_{(2,4)}$ also includes the analogous losses for the spans "he is very happy" and "he is very happy now".
  • Figure 2: Comparing TreeReg LM with the Base LM from the baselines table on SG test suites. TreeReg LM outperforms the Base LM on 4 out of 6 test suites, with 1 tie.
  • Figure 3: Syntactic generalization on SG test suites versus the percentage of BLLIP-LG data used to train LMs from scratch. TreeReg LM exceeds the maximum syntactic generalization performance of the Base LM with less than 50% of the data.
  • Figure 4: Unlabeled F1 scores on the BLLIP-LG test set for parse trees induced from every layer of a 16-layer TreeReg LM trained on BLLIP-LG with $\mathcal{L}_\text{TR}$ applied at layer 12. Circuits become increasingly tree-structured up to layer 12, then rapidly lose this structure.
  • Figure 5: Syntactic generalization on SG test suites versus the percentage of BLLIP-LG data for which parses are provided to TreeReg, for LMs trained from scratch on BLLIP-LG. Even when TreeReg is given parses for only 1% of the data, TreeReg LMs have better syntactic generalization than baseline LMs.
  • ...and 11 more figures
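Figure 4 scores parse trees induced from each layer with unlabeled F1. One simple way to induce a tree from a SCIN chart is greedy top-down splitting, sketched below. The decoder is our assumption (the paper's induction procedure may differ), and `greedy_parse` and `unlabeled_f1` are hypothetical names. The F1 here counts every predicted span of length at least 2, including the whole sentence; standard evaluation scripts apply their own conventions about which spans to count.

```python
from typing import Dict, Set, Tuple

Span = Tuple[int, int]


def greedy_parse(chart: Dict[Span, float], i: int, j: int) -> Set[Span]:
    """Greedy top-down bracketing of the 1-indexed span (i, j) from a SCIN chart.

    At each span, pick the split k maximising SCIN(i, k) + SCIN(k + 1, j),
    then recurse on both halves. Returns all spans of length >= 2.
    """
    if i == j:
        return set()  # single words contribute no brackets
    best_k = max(range(i, j), key=lambda k: chart[(i, k)] + chart[(k + 1, j)])
    return {(i, j)} | greedy_parse(chart, i, best_k) | greedy_parse(chart, best_k + 1, j)


def unlabeled_f1(pred: Set[Span], gold: Set[Span]) -> float:
    """Unlabeled bracketing F1 between predicted and gold span sets."""
    overlap = len(pred & gold)
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)
```

For Figure 1's sentence, a chart that scores the gold constituents and single words at 0 and every other span at -1 makes `greedy_parse(chart, 1, 5)` return exactly {(1, 5), (1, 4), (2, 4), (3, 4)}, the bracketing ((he (is (very happy))) now).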