On the Robustness of Transformers against Context Hijacking for Linear Classification

Tianle Li, Chenyang Zhang, Xingwu Chen, Yuan Cao, Difan Zou

TL;DR

This work analyzes the robustness of transformers to context hijacking by modeling in-context linear classification with multi-layer linear transformers. It shows that such transformers can implement multi-step gradient descent on context examples, and derives optimal initialization and step-size schedules that depend on the training context length $n$ and depth $L$. The results reveal that deeper models and longer training contexts yield finer optimization steps, reducing interference from hijacked context, with bounds and empirical validation supporting the theory. The findings provide theoretical justification for why deeper architectures can exhibit improved robustness to context hijacking and offer a framework transferable to other gradient-descent-based in-context learning problems.

Abstract

Transformer-based Large Language Models (LLMs) have demonstrated powerful in-context learning capabilities. However, their predictions can be disrupted by factually correct context, a phenomenon known as context hijacking, revealing a significant robustness issue. To understand this phenomenon theoretically, we explore an in-context linear classification problem based on recent advances in linear transformers. In our setup, context tokens are designed as factually correct query-answer pairs, where the queries are similar to the final query but have opposite labels. Then, we develop a general theoretical analysis on the robustness of the linear transformers, which is formulated as a function of the model depth, training context lengths, and number of hijacking context tokens. A key finding is that a well-trained deeper transformer can achieve higher robustness, which aligns with empirical observations. We show that this improvement arises because deeper layers enable more fine-grained optimization steps, effectively mitigating interference from context hijacking. This is also well supported by our numerical experiments. Our findings provide theoretical insights into the benefits of deeper architectures and contribute to enhancing the understanding of transformer architectures.
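The hijacked-context setup described in the abstract can be illustrated with a small NumPy sketch. This is not the paper's exact construction: the function name `make_hijacked_context`, the `margin` parameter, and the mirror-image placement of the hijacking examples are assumptions made here for illustration. The sketch only shows how context examples can be factually correct (labelled by the true classifier), close to the query, and still carry the opposite label.

```python
import numpy as np

def make_hijacked_context(w_star, direction, margin, num_hijack):
    """Build a query and `num_hijack` hijacking examples near the boundary of w_star.

    Hypothetical construction: the query sits at distance `margin` on the
    positive side of the decision boundary, and each hijacking example is
    its mirror image on the negative side.
    """
    w_star = w_star / np.linalg.norm(w_star)
    # Project out the w_star component so that u lies in the decision boundary.
    u = direction - np.dot(direction, w_star) * w_star
    u = u / np.linalg.norm(u)
    x_query = u + margin * w_star
    x_hijack = u - margin * w_star
    X_hijack = np.tile(x_hijack, (num_hijack, 1))
    # Labels come from the true classifier, so the context is factually correct.
    y_hijack = np.sign(X_hijack @ w_star)
    y_query = np.sign(x_query @ w_star)
    return x_query, y_query, X_hijack, y_hijack

w_star = np.array([1.0, 0.0, 0.0])
x_query, y_query, X_hijack, y_hijack = make_hijacked_context(
    w_star, direction=np.array([0.0, 1.0, 0.0]), margin=0.2, num_hijack=5
)
# Cosine similarity between one hijacking example and the query.
cosine = (X_hijack[0] @ x_query) / (
    np.linalg.norm(X_hijack[0]) * np.linalg.norm(x_query)
)
print(y_query, y_hijack, float(cosine))
```

A smaller `margin` makes the hijacking examples more similar to the query while keeping their labels opposite.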

Paper Structure

This paper contains 25 sections, 15 theorems, 76 equations, 6 figures.

Key Result

Proposition 4.1

For any $L$-layer single-head linear transformer, let $\widehat{y}_{\mathrm{query}}^{(l)}$ be the output of the $l$-th layer of the transformer, i.e., the $(d+1,n+1)$-th entry of $\mathbf{Z}_l$. Then, there exists a single-head linear transformer with $L$ layers such that $\widehat{y}_{\mathrm{query}}^{(l)}$ […]. Here, $\bm{\Gamma}_{l}$ can be any $d\times d$ matrix.
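The displayed identity in the proposition is incomplete here, so the following is only a sketch of the kind of recursion the TL;DR describes: each layer performs one gradient descent step on the context examples, preconditioned by a $d\times d$ matrix $\bm{\Gamma}_l$. The squared loss, the normalization by a fixed training length `n_train`, the direct read-out of the query logit, and all names are assumptions of this sketch, not statements from the paper.

```python
import numpy as np

def layerwise_gd_predictions(w_init, gammas, X_ctx, y_ctx, x_query, n_train):
    """Return the query logit after each layer, one preconditioned GD step per layer.

    Assumed update (not taken from the source):
        w <- w - Gamma_l @ grad,
        grad = (1 / n_train) * sum_i (<w, x_i> - y_i) * x_i
    """
    w = np.array(w_init, dtype=float)
    logits = []
    for gamma in gammas:
        residual = X_ctx @ w - y_ctx
        grad = X_ctx.T @ residual / n_train
        w = w - gamma @ grad
        logits.append(float(w @ x_query))
    return logits

# Toy data: the query has true label +1; hijacking examples are close to it
# but lie on the other side of the boundary of w_star = e_1 (true label -1).
x_query = np.array([0.2, 1.0, 0.0])
x_hijack = np.array([-0.2, 1.0, 0.0])
w_init = np.array([1.0, 0.0, 0.0])            # hypothetical initialization
gammas = [0.1 * np.eye(3) for _ in range(3)]  # L = 3 layers, scalar step 0.1

def final_logit(num_hijack):
    X_ctx = np.tile(x_hijack, (num_hijack, 1))
    y_ctx = -np.ones(num_hijack)
    return layerwise_gd_predictions(
        w_init, gammas, X_ctx, y_ctx, x_query, n_train=10
    )[-1]

for num_hijack in (0, 5, 100):
    print(num_hijack, np.sign(final_logit(num_hijack)))
```

In this toy run the predicted sign stays positive with no or few hijacking examples and flips once many are prepended, because the gradient is normalized by `n_train` rather than by the number of prepended examples.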

Figures (6)

  • Figure 1: Context hijacking phenomenon in LLMs of different depths. Left: With no or only a few factually correct prepends, LLMs of different depths correctly predict the next token. As the number of prepends increases, the model outputs are disrupted, and deeper models are more robust. Right: Four different types of tasks are introduced, each with a fixed template, and tested on LLMs of different depths. The horizontal axis lists the models in order of increasing depth, and the vertical axis shows the average number of prepends required to successfully interfere with the model output. Experiments show that deeper models are more robust. (The experimental setup is given in the appendix.)
  • Figure 2: Gradient descent experiments using a single-layer neural network. We use grid search to obtain the optimal learning rate for different training context lengths $n$ and different numbers of gradient descent steps $L$. We then use the corresponding optimal learning rate to perform multi-step gradient descent optimization on the test dataset. The results show that longer training context lengths and more gradient descent steps lead to a smaller optimal learning rate and better optimization.
  • Figure 3: Linear transformers experiments with different depths and different training context lengths. Testing the trained linear transformers on the test set shows that prediction accuracy degrades as the number of interference samples increases. However, deeper models retain higher accuracy, indicating stronger robustness. Robustness also increases with the training context length, since the accuracy converges significantly more slowly.
  • Figure 4: Linear transformers experiments on training dataset. When the trained linear transformers are tested on the training set, the initial accuracy is high and improves further as the context length increases, indicating that the model can use in-context learning to fine-tune $\bm{\beta}^\star$ to $\mathbf{w}^\star$. Deeper models also have stronger optimization capabilities.
  • Figure 5: Standard transformers experiments with different depths. When the trained standard transformers (GPT-2 architecture; Radford et al., 2019) are tested on the test set, the classification accuracy decreases and gradually converges as the number of interference samples increases. The results also show that deeper models are more robust.
  • ...and 1 more figure
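The grid search described for Figure 2 could look roughly like the sketch below. The details are assumptions rather than the authors' setup: Gaussian features, labels $\mathrm{sign}(\langle \mathbf{w}^\star, \mathbf{x}\rangle)$, a squared-loss gradient step, query misclassification rate as the selection criterion, and the names `gd_error` and `grid_search_lr` are all hypothetical.

```python
import numpy as np

def gd_error(lr, steps, n, w_init, w_star, rng, trials=100):
    """Query misclassification rate after `steps` GD steps on n context examples."""
    d = w_star.shape[0]
    errors = 0
    for _ in range(trials):
        X = rng.standard_normal((n, d))
        y = np.sign(X @ w_star)
        w = w_init.copy()
        for _ in range(steps):
            # Gradient of the (assumed) squared loss, averaged over the context.
            w = w - lr * (X.T @ (X @ w - y)) / n
        x_q = rng.standard_normal(d)
        errors += int(np.sign(w @ x_q) != np.sign(x_q @ w_star))
    return errors / trials

def grid_search_lr(grid, steps, n, w_init, w_star, seed=0):
    """Evaluate every learning rate on the same random tasks and return the best."""
    losses = [
        gd_error(lr, steps, n, w_init, w_star, np.random.default_rng(seed))
        for lr in grid
    ]
    best = int(np.argmin(losses))
    return grid[best], losses

d = 5
w_star = np.ones(d) / np.sqrt(d)
w_init = np.zeros(d)  # hypothetical initialization
grid = [0.05, 0.1, 0.2, 0.4, 0.8]
best_lr, losses = grid_search_lr(grid, steps=4, n=20, w_init=w_init, w_star=w_star)
print(best_lr, losses)
```

Reseeding the generator for each candidate learning rate evaluates all candidates on identical tasks, so differences in error come from the step size rather than from sampling noise. Repeating the search over several values of `steps` and `n` would be the natural way to compare against the trend reported in the caption.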

Theorems & Definitions (17)

  • Definition 3.1: Data distribution
  • Definition 3.2: Hijacked context data
  • Proposition 4.1
  • Proposition 4.2
  • Theorem 4.3
  • Lemma 4.4
  • Theorem 4.5
  • Lemma 2.1
  • Lemma 3.1
  • Theorem 3.2: Restatement of Theorem \ref{thm:training_loss_bound}
  • ...and 7 more