Transformer Layers as Painters

Qi Sun, Marc Pickett, Aakash Kumar Nain, Llion Jones

TL;DR

The paper investigates how information flows across transformer layers in frozen pretrained architectures using a painters analogy and systematic ablations. It finds that middle layers share a representation space and can be partially skipped or reordered with graceful degradation, while first and last layers are more specialized. Looping, randomization, and parallel execution provide latency-accuracy tradeoffs and improve robustness on benchmarks such as ARC, GSM8K, and GLUE. Across scales and architectures the results reveal a three-class layer taxonomy and suggest practical routing-based latency reductions that do not require fine-tuning.

Abstract

Despite their nearly universal adoption for large language models, the internal workings of transformers are not well understood. We aim to better understand the impact of removing or reorganizing information throughout the layers of a pretrained transformer. Such an understanding could yield both better usage of existing models and architectural improvements that produce new variants. We present a series of empirical studies on frozen models showing that the lower and final layers of pretrained transformers differ from the middle layers, but that the middle layers have a surprising amount of uniformity. We further show that some classes of problems are robust to skipping layers, running the layers in an order different from the one they were trained in, or running the layers in parallel. Our observations suggest that even frozen pretrained models may gracefully trade accuracy for latency by skipping layers or running layers in parallel.
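The execution strategies named in the abstract (skipping layers, running them in a different order, running them in parallel) can be illustrated with a toy sketch. Everything here is an assumption made for illustration: the "layers" are plain Python functions rather than transformer blocks, the helper names are hypothetical, and averaging the parallel outputs is one possible way to merge them, not necessarily the paper's exact procedure.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def make_toy_layer(scale: float) -> Layer:
    """Hypothetical stand-in for a transformer block: a residual update x + scale * x."""
    def layer(x: Vector) -> Vector:
        return [v + scale * v for v in x]
    return layer


def run_sequential(layers: Sequence[Layer], x: Vector) -> Vector:
    """Apply the layers one after another, in the given order."""
    for layer in layers:
        x = layer(x)
    return x


def run_skip(layers: Sequence[Layer], x: Vector, start: int, end: int) -> Vector:
    """Run the stack with layers[start:end] removed."""
    kept = list(layers[:start]) + list(layers[end:])
    return run_sequential(kept, x)


def run_reversed_middle(layers: Sequence[Layer], x: Vector, start: int, end: int) -> Vector:
    """Run layers[start:end] in reverse order, leaving the outer layers untouched."""
    middle = list(layers[start:end])[::-1]
    reordered = list(layers[:start]) + middle + list(layers[end:])
    return run_sequential(reordered, x)


def run_parallel_middle(layers: Sequence[Layer], x: Vector, start: int, end: int) -> Vector:
    """Feed the same input to every middle layer and average their outputs (assumed merge rule)."""
    x = run_sequential(layers[:start], x)
    outputs = [layer(x) for layer in layers[start:end]]
    if outputs:
        x = [sum(values) / len(outputs) for values in zip(*outputs)]
    return run_sequential(layers[end:], x)


if __name__ == "__main__":
    # Six toy layers; indices 2..3 play the role of the "middle" block.
    stack = [make_toy_layer(0.1 * (i + 1)) for i in range(6)]
    inputs = [1.0, -2.0]
    print("sequential:", run_sequential(stack, inputs))
    print("skip:      ", run_skip(stack, inputs, 2, 4))
    print("reversed:  ", run_reversed_middle(stack, inputs, 2, 4))
    print("parallel:  ", run_parallel_middle(stack, inputs, 2, 4))
```

In this sketch the first and last layers always run in their original positions, mirroring the abstract's observation that they differ from the middle layers; only the `start:end` slice is skipped, reversed, or parallelized.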

Paper Structure (18 sections, 16 figures)

Figures (16)

  • Figure 1: Different execution strategies.
  • Figure 2: Results for Open-LAMBADA from skipping layer $N$ (blue), and from switching layer $N$ with $N+1$ (green) of Llama2-7B. Skipping early layers has a catastrophic effect, while the model is much more robust to skipping middle layers.
  • Figure 3: Avg. cosine similarity between the hidden states of all 32 layers of Llama2-7B (top) and all 40 layers of Llama2-13B.
  • Figure 4: Top: Skipping layers $N$ to $32-N$ for Llama2-7B, normalized per benchmark (median). Bottom: Skipping layers $N$ to $24-N$ for BERT, with unnormalized average.
  • Figure 5: Replacing $M$ middle layers with the center layer (16 for Llama, 12 for BERT) for Llama2-7B (top, normalized benchmarks) and BERT (unnormalized average).
  • ...and 11 more figures
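
Figure 3 reports the average cosine similarity between the hidden states of different layers. As a rough illustration of how such a layer-by-layer matrix could be computed, the sketch below uses plain Python lists. The data layout (`hidden_states[layer][token]` is a vector), the function names, and the choice to average over token positions are assumptions of this sketch, not details taken from the paper.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine of the angle between two vectors; returns 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def layer_similarity_matrix(hidden_states: Sequence[Sequence[Vector]]) -> List[List[float]]:
    """Average, over token positions, the cosine similarity between every pair of layers.

    hidden_states[l][t] is assumed to be the hidden vector of layer l at token t,
    with the same number of tokens for every layer.
    """
    n_layers = len(hidden_states)
    matrix = [[0.0] * n_layers for _ in range(n_layers)]
    for i in range(n_layers):
        for j in range(n_layers):
            sims = [
                cosine_similarity(h_i, h_j)
                for h_i, h_j in zip(hidden_states[i], hidden_states[j])
            ]
            matrix[i][j] = sum(sims) / len(sims)
    return matrix
```

A block of high off-diagonal values in such a matrix would be consistent with a group of layers sharing a representation space, which is how the TL;DR describes the middle layers.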