Sneaking Syntax into Transformer Language Models with Tree Regularization
Ananjan Nandi, Christopher D. Manning, Shikhar Murty
TL;DR
TreeReg introduces a differentiable regularizer that softly imposes hierarchical, constituency-based structure on transformer language models without changing their architecture. It defines a Span Contextual Independence Score (SCIN), built on orthogonality constraints between prefix and span representations, and uses it to encourage constituent representations to be independent of their surrounding context, guiding the model toward tree-like computations. In both pre-training and fine-tuning, TreeReg improves syntactic generalization, substantially lowers out-of-distribution perplexity, and increases robustness on adversarial NLI benchmarks, while remaining efficient and allowing parses to be recovered from hidden states. The results show that architecture-agnostic, regularization-based inductive biases can improve how language models learn linguistic structure, with notable sample-efficiency benefits and applicability to both LMs and pre-trained LLMs. The authors release code to support replication and future work.
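To make the orthogonality idea concrete, here is a minimal, hypothetical sketch of a SCIN-style score. It assumes (this is not necessarily the paper's exact definition) that, for a span covering positions i..j, the prefix representation is the hidden state at position i - 1 and the span representation is the difference between the hidden states at j and i - 1. The score is 1 minus the absolute cosine similarity of the two vectors, so it reaches 1.0 when they are orthogonal. The names `cosine` and `scin_like_score` are illustrative, and plain Python lists stand in for tensors; a training implementation would use a tensor library so that gradients flow through the score.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity between two equal-length vectors (eps avoids 0/0)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + eps)


def scin_like_score(hidden: List[Vector], i: int, j: int) -> float:
    """Hypothetical SCIN-style score for the inclusive span [i, j].

    Assumption: prefix vector = hidden[i - 1]; span vector =
    hidden[j] - hidden[i - 1]. Returns 1 - |cos(prefix, span)|, i.e. 1.0
    when the span's contribution is orthogonal to the prefix and about 0.0
    when it is parallel. Spans starting at position 0 have no prefix here.
    """
    if not 1 <= i <= j < len(hidden):
        raise ValueError("need 1 <= i <= j < len(hidden)")
    prefix = hidden[i - 1]
    # Difference vector: what the span adds on top of the prefix state.
    span = [b - a for a, b in zip(prefix, hidden[j])]
    return 1.0 - abs(cosine(prefix, span))
```

For example, with `hidden = [[1.0, 0.0], [1.0, 1.0], [2.0, 0.0]]`, the span [1, 1] scores 1.0 because its difference vector is orthogonal to the prefix, while the span [1, 2] scores close to 0.0 because its difference vector is parallel to the prefix.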
Abstract
While compositional accounts of human language understanding are based on a hierarchical tree-like process, neural models like transformers lack a direct inductive bias for such tree structures. Introducing syntactic inductive biases could unlock more robust and data-efficient learning in transformer language models (LMs), but existing methods for incorporating such structure greatly restrict models, either limiting their expressivity or increasing inference complexity. This work instead aims to softly inject syntactic inductive biases into given transformer circuits through a structured regularizer. We introduce TreeReg, an auxiliary loss function that converts bracketing decisions from silver parses into a set of differentiable orthogonality constraints on vector hidden states. TreeReg integrates seamlessly with the standard LM objective, requiring no architectural changes. LMs pre-trained with TreeReg on natural language corpora such as WikiText-103 achieve up to 10% lower perplexities on out-of-distribution data and up to 9.5 point improvements in syntactic generalization, requiring less than half the training data to outperform standard LMs. TreeReg still provides gains for pre-trained LLMs: continued pre-training of Sheared Llama with TreeReg results in improved syntactic generalization, and fine-tuning on MultiNLI with TreeReg mitigates degradation of performance on adversarial NLI benchmarks by 41.2 points. We release all code to guide future research.
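The abstract describes turning bracketing decisions from silver parses into a differentiable loss that is added to the standard LM objective. One way to sketch this, under stated assumptions, is a softmax over the candidate split points of a span: each split k is scored by summing the independence scores of the two resulting child spans, and the loss is the negative log-likelihood of the split chosen by the silver parse. This formulation, the helper names `split_loss` and `total_loss`, and the weight `lam` are hypothetical and not taken from the released code; the span scores could be SCIN-style scores as described in the TL;DR.

```python
import math
from typing import Dict, Iterable, Tuple

Span = Tuple[int, int]  # inclusive (start, end) token positions


def split_loss(scores: Dict[Span, float], i: int, j: int, gold_k: int) -> float:
    """Hypothetical NLL of the silver-parse split point for span [i, j].

    Each candidate split k (i <= k < j) divides the span into children
    [i, k] and [k + 1, j]; its logit is the sum of the two children's
    scores. A softmax over candidates makes the bracketing decision
    differentiable with respect to the scores.
    """
    if not i <= gold_k < j:
        raise ValueError("gold split must satisfy i <= gold_k < j")
    logits = {k: scores[(i, k)] + scores[(k + 1, j)] for k in range(i, j)}
    # Log-sum-exp with max subtraction for numerical stability.
    m = max(logits.values())
    log_z = m + math.log(sum(math.exp(v - m) for v in logits.values()))
    return log_z - logits[gold_k]


def total_loss(lm_loss: float, reg_terms: Iterable[float], lam: float = 1.0) -> float:
    """LM loss plus `lam` times the mean of the auxiliary per-span terms."""
    terms = list(reg_terms)
    reg = sum(terms) / len(terms) if terms else 0.0
    return lm_loss + lam * reg
```

With `scores = {(1, 1): 0.9, (2, 3): 0.8, (1, 2): 0.2, (3, 3): 0.3}`, calling `split_loss(scores, 1, 3, gold_k=1)` compares the split logits 1.7 and 0.5 and returns about 0.263; the loss shrinks as the silver-parse split outscores the alternatives. Averaging the per-span terms before weighting by `lam` keeps the regularizer's scale independent of how many constituents a batch contains, which is a design choice of this sketch rather than a detail from the source.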
