Attend First, Consolidate Later: On the Importance of Attention in Different LLM Layers

Amit Ben-Artzy, Roy Schwartz

TL;DR

The results hint at a two-stage process in transformer-based LLMs: the first part gathers input from previous tokens, while the second mainly processes that information internally.

Abstract

In decoder-based LLMs, the representation of a given layer serves two purposes: as input to the next layer during the computation of the current token, and as input to the attention mechanism of future tokens. In this work, we show that the importance of the latter role might be overestimated. To show that, we start by manipulating the representations of previous tokens, e.g., by replacing the hidden states at some layer k with random vectors. Our experiments with four LLMs and four tasks show that this operation often leads to a small to negligible drop in performance. Importantly, this happens if the manipulation occurs in the top part of the model, i.e., when k is in the final 30-50% of the layers. In contrast, doing the same manipulation in earlier layers might lead to chance-level performance. We continue by switching the hidden states of certain tokens with the hidden states of other tokens from another prompt, e.g., replacing the word "Italy" with "France" in "What is the capital of Italy?". We find that when applying this switch in the top 1/3 of the model, the model ignores it (answering "Rome"). However, if we apply it earlier, the model conforms to the switch ("Paris"). Our results hint at a two-stage process in transformer-based LLMs: the first part gathers input from previous tokens, while the second mainly processes that information internally.
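
The two manipulations described in the abstract (overwriting previous tokens' hidden states at layer k with random vectors, and switching one token's hidden state with the one from another prompt) can be illustrated on a toy model. The following is a minimal sketch, not the authors' code: the function names (`causal_self_attention`, `forward`), the projection-free attention, the `tanh` feed-forward stand-in, and the random weights are all assumptions made for illustration, so the numbers it prints say nothing about real LLMs.

```python
import numpy as np

def causal_self_attention(h):
    """Toy single-head causal attention on normalized states (no learned projections)."""
    hn = h / np.linalg.norm(h, axis=-1, keepdims=True)
    scores = hn @ hn.T
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)      # position i sees positions <= i
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ hn

def forward(x, layer_weights, patch_layer=None, patch=None):
    """Run the toy stack and return the hidden states after every layer.

    If patch_layer = k, the positions listed in `patch` (dict: position -> vector)
    are overwritten after layer k; layers k+1 and above then run as normal.
    """
    h = x.copy()
    states = [h.copy()]                      # states[k] = hidden states after k layers
    for k, w in enumerate(layer_weights, start=1):
        h = h + causal_self_attention(h)     # attention sub-layer with residual
        h = h + np.tanh(h @ w)               # feed-forward stand-in with residual
        if patch_layer == k:
            for pos, vec in patch.items():
                h[pos] = vec
        states.append(h.copy())
    return states

rng = np.random.default_rng(0)
T, d, n_layers = 5, 8, 6                     # hypothetical toy sizes
weights = [rng.normal(scale=0.3, size=(d, d)) for _ in range(n_layers)]
prompt_a = rng.normal(size=(T, d))           # stands in for "... capital of Italy?"
prompt_b = prompt_a.copy()
prompt_b[3] = rng.normal(size=d)             # same prompt with one token changed

clean_a = forward(prompt_a, weights)
clean_b = forward(prompt_b, weights)

k = 4
# (1) random vectors for all previous tokens at layer k
noise = {pos: rng.normal(size=d) for pos in range(T - 1)}
noised = forward(prompt_a, weights, patch_layer=k, patch=noise)
# (2) switch one token's hidden state with the one from the other prompt
switched = forward(prompt_a, weights, patch_layer=k, patch={3: clean_b[k][3]})

def drift(states):
    """Distance between the last token's final state and the clean run."""
    return float(np.linalg.norm(states[-1][-1] - clean_a[-1][-1]))

print("noise drift:", drift(noised), "switch drift:", drift(switched))
```

In an actual experiment the same idea would be applied to a pretrained model's hidden states, and the effect would be measured by task accuracy rather than by a vector distance.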

Paper Structure (21 sections, 2 equations, 6 figures)

Figures (6)

  • Figure 1: To evaluate the role of previous hidden states as input to the attention mechanism, we devise two setups: (a) we replace the hidden state at layer $k$ with a random vector, and use it as input to layer $k+1$, which continues processing as normal; (b) starting from a given layer $k+1$, the hidden representations of previous tokens are frozen, and the attention mechanism attends to their hidden states at layer $k$.
  • Figure 2: Manipulating the history tokens of different LLMs on the capitals dataset across different layers. We observe that all models become robust to the freeze manipulation after about 15 layers ($\approx 50\%$ of the layers), and to the other manipulations after about 20--25 layers.
  • Figure 3: Manipulation results on both versions of the math exercises dataset with the Llemma model. The model is highly resilient to the freeze manipulation starting at layer 16 in both cases, but is far less robust to the other manipulations.
  • Figure 4: The effect of different manipulations on the SQuAD dataset. The Llama2-7B model is resilient to all manipulations after 18 (freeze) to 27 (skip attention) layers. Interestingly, results improve over the baseline if these manipulations are applied later. Yi is resilient to freezing (25 layers), though not to the other manipulations. Mistral is not resilient to any manipulation.
  • Figure 5: The effect of different manipulations on LLM performance on the CNN/Daily Mail dataset. Yi-6B reaches baseline performance at top layers. As before, Llama2-7B is resilient to freezing starting at layer 20. Other models reach similar performance, but still inferior to the baseline. No model is resilient to the other manipulations.
  • ...and 1 more figure
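
The freeze setup from Figure 1(b) can be sketched on a toy model as follows. This is an illustrative sketch under stated assumptions, not the paper's implementation: `attend`, `last_token_output`, and the `freeze_after` parameter are hypothetical names, the attention has no learned projections, and the weights are random, so the printed distances do not reproduce any result from the figures.

```python
import numpy as np

def attend(h):
    """Toy causal attention: position i attends to positions <= i."""
    hn = h / np.linalg.norm(h, axis=-1, keepdims=True)
    scores = hn @ hn.T
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ hn

def last_token_output(x, layer_weights, freeze_after=None):
    """Final hidden state of the last token.

    With freeze_after = k, layers k+1 and above attend to the previous tokens'
    hidden states as they were after layer k; only the last token keeps
    being updated.
    """
    h = x.copy()
    frozen = h[:-1].copy() if freeze_after == 0 else None
    for layer, w in enumerate(layer_weights, start=1):
        if frozen is not None:
            h[:-1] = frozen              # previous tokens stay at their layer-k states
        h = h + attend(h)                # attention sub-layer with residual
        h = h + np.tanh(h @ w)           # feed-forward stand-in with residual
        if freeze_after == layer:
            frozen = h[:-1].copy()
    return h[-1]

rng = np.random.default_rng(0)
T, d, n_layers = 6, 8, 8                 # hypothetical toy sizes
weights = [rng.normal(scale=0.3, size=(d, d)) for _ in range(n_layers)]
x = rng.normal(size=(T, d))

baseline = last_token_output(x, weights)
for k in range(n_layers + 1):
    frozen_out = last_token_output(x, weights, freeze_after=k)
    print(k, float(np.linalg.norm(frozen_out - baseline)))
```

Freezing after the final layer leaves the output unchanged by construction, which gives a quick sanity check for any implementation of this manipulation.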