
A Transformer with Stack Attention

Jiaoda Li, Jennifer C. White, Mrinmaya Sachan, Ryan Cotterell

TL;DR

This work proposes a differentiable stack-based attention mechanism that can be plugged into transformers to improve their ability to learn deterministic context-free (CF) languages. By emulating a stack over the set of input indices and inserting a stack-attention sub-layer into each transformer layer, the approach improves performance on several CF tasks (notably reverse string and stack manipulation) while remaining interpretable through stack-top attention maps. The results also reveal limitations: modular arithmetic remains challenging, and the benefit of the stack augmentation diminishes as the amount of training data grows. This suggests that the method supplies an inductive bias that is useful in data-scarce settings rather than a universal solution for CF languages. Overall, the paper presents a modular, interpretable augmentation that extends transformer expressivity for hierarchical structure, and it outlines practical and theoretical limitations along with directions for future work.

Abstract

Natural languages are believed to be (mildly) context-sensitive. Despite underpinning remarkably capable large language models, transformers are unable to model many context-free language tasks. In an attempt to address this limitation in the modeling power of transformer-based language models, we propose augmenting them with a differentiable, stack-based attention mechanism. Our stack-based attention mechanism can be incorporated into any transformer-based language model and adds a level of interpretability to the model. We show that the addition of our stack-based attention mechanism enables the transformer to model some, but not all, deterministic context-free languages.
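A differentiable stack over input positions can be illustrated with a soft, superposition-style update, in which each step's stack is a convex mixture of the push, no-op, and pop outcomes (the kind of update used in the stack-augmented RNNs of Joulin and Mikolov, 2015). The Python below is a minimal sketch under that assumption, not the authors' implementation: the names soft_stack_tops and read_top, the fixed depth, and the rule that a push at timestep t points at input position t are illustrative choices. Each stack slot holds an attention distribution over input positions, and the top slot is used to read a weighted sum of value vectors.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def soft_stack_tops(
    actions: Sequence[Tuple[float, float, float]], n: int, depth: int
) -> List[Vector]:
    """Top-of-stack attention at each step of a soft stack over n input positions.

    Hypothetical helper. actions[t] = (p_push, p_noop, p_pop) must sum to 1.
    `depth` is an assumed fixed capacity; elements pushed past it fall off the bottom.
    """
    if len(actions) > n:
        raise ValueError("at most one action per input position")
    zero = [0.0] * n
    stack: List[Vector] = [zero[:] for _ in range(depth)]  # slot 0 is the top
    tops: List[Vector] = []
    for t, (p_push, p_noop, p_pop) in enumerate(actions):
        if abs(p_push + p_noop + p_pop - 1.0) > 1e-6:
            raise ValueError("action probabilities must sum to 1")
        e_t = [1.0 if j == t else 0.0 for j in range(n)]  # a push points at position t
        pushed = [e_t] + stack[:-1]  # candidate stack after PUSH
        popped = stack[1:] + [zero]  # candidate stack after POP
        # Superposition: convex mix of the PUSH, NO-OP, and POP candidates, slot by slot.
        stack = [
            [p_push * a + p_noop * b + p_pop * c for a, b, c in zip(pu, st, po)]
            for pu, st, po in zip(pushed, stack, popped)
        ]
        tops.append(stack[0][:])
    return tops


def read_top(top: Vector, values: Sequence[Vector]) -> Vector:
    """Stack-attention readout: value vectors weighted by the top-of-stack attention."""
    dim = len(values[0])
    return [sum(w * v[k] for w, v in zip(top, values)) for k in range(dim)]
```

With one-hot action probabilities the update reduces to a discrete stack whose top is a one-hot pointer into the input; fractional probabilities blend the candidate stacks, and because the update is linear in the action probabilities it is differentiable with respect to them.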

Paper Structure

This paper contains 50 sections, 4 theorems, 24 equations, 5 figures, and 3 tables.

Key Result

Theorem 3.1

Let $\upsilon_1, \ldots, \upsilon_N$ be a sequence of stack operations, where $\upsilon_i \in \{\texttt{PUSH}_i(\cdot), \texttt{NO-OP}(\cdot), \texttt{POP}(\cdot)\}$ for all $i \in [N]$. Furthermore, suppose

Figures (5)

  • Figure 1: An example illustrating how attention can emulate a stack. The first column lists the operation performed at each timestep, the second column shows the stack contents after that operation, and the third column shows a hard attention over the input tokens whose pointer marks the current stack top. The last column shows the proposed stack attention.
  • Figure 2: Stack attention maps at different layers for the reverse string (RS) task. The input $\boldsymbol{x}$ is $\texttt{a}\,\texttt{b}\,\texttt{b}\,\texttt{a}\,\texttt{a}$. $\texttt{M}$ represents a $\textsc{[mask]}$ token.
  • Figure 3: Stack attention maps at different layers for the stack manipulation (SM) task. The input $\boldsymbol{x}$ is $\texttt{a}\,\texttt{b}\,[\textsc{pop}]\,[\textsc{push}\ \texttt{a}]\,[\textsc{push}\ \texttt{b}]$. In the graphs, $[\textsc{push}\ \texttt{a}]$, $[\textsc{push}\ \texttt{b}]$, and $[\textsc{pop}]$ are abbreviated as $\texttt{a}$, $\texttt{b}$, and $\texttt{P}$, respectively. $\texttt{M}$ represents a $\textsc{[mask]}$ token. The correct output is $\texttt{b}\,\texttt{b}\,\texttt{a}$ followed by $\textsc{[pad]}$ tokens.
  • Figure 4: Stack attention maps at different layers for the modular arithmetic (MA) task. The input $\boldsymbol{x}$ is $((4)\cdot(-0))=$.
  • Figure 5: Stack attention maps at different layers for the solve equation (SE) task. The input $\boldsymbol{x}$ is $(1+(-z))=3$. $\texttt{M}$ represents a $\textsc{[mask]}$ token.
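The first three columns of Figure 1 can be reproduced with a short discrete simulation. This is a hedged sketch rather than the paper's code: the function name hard_stack_attention, the string labels for operations, and the convention that a PUSH at timestep i pushes input index i (mirroring the PUSH_i notation of Theorem 3.1) are assumptions for illustration. Because the stack only ever stores input indices, its top can be expressed as a one-hot (hard) attention vector over the input tokens.

```python
from typing import List, Optional, Sequence, Tuple

Row = Tuple[str, str, Optional[int], List[int]]


def hard_stack_attention(tokens: Sequence[str], ops: Sequence[str]) -> List[Row]:
    """Simulate a stack of input indices (hypothetical helper, not the paper's code).

    ops[i] is "PUSH", "NO-OP", or "POP"; a PUSH at timestep i pushes index i.
    Each row is (operation, stack contents, top index or None, one-hot top attention).
    """
    if len(tokens) != len(ops):
        raise ValueError("need exactly one operation per input token")
    stack: List[int] = []  # stores positions into `tokens`, never copies of tokens
    rows: List[Row] = []
    for i, op in enumerate(ops):
        if op == "PUSH":
            stack.append(i)
        elif op == "POP":
            if stack:  # popping an empty stack is treated as a no-op in this sketch
                stack.pop()
        elif op != "NO-OP":
            raise ValueError(f"unknown operation: {op!r}")
        top = stack[-1] if stack else None
        # Hard attention: weight 1 on the stack-top position, 0 elsewhere.
        one_hot = [1 if j == top else 0 for j in range(len(tokens))]
        contents = "".join(tokens[j] for j in stack)
        rows.append((op, contents, top, one_hot))
    return rows
```

For example, with tokens a b c d and operations PUSH, PUSH, POP, PUSH, the final row is ("PUSH", "ad", 3, [0, 0, 0, 1]): the stack holds positions 0 and 3, and the attention points at position 3.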

Theorems & Definitions (8)

  • Theorem 3.1
  • proof
  • Theorem 3.2
  • proof
  • Theorem A.1
  • proof
  • Theorem A.1
  • proof