Adaptive Large Language Models By Layerwise Attention Shortcuts

Prateek Verma, Mert Pilanci

TL;DR

The work addresses inflexible depth in decoder-only transformers by enabling adaptive computations through attention shortcuts that let the final layer attend to intermediate-layer representations. It introduces learned MLP-based feature maps on selected depths and a cross-attention mechanism in the last layer over concatenated intermediate features under a causal mask. Across speech, text, and music datasets, the approach yields lower next-token NLL and demonstrates, via attention maps, that the model adaptively allocates computation across depth depending on input complexity. This depth- and context-adaptive strategy promises more efficient pre-training and scalable performance, with ablations and scaling analyses guiding future integration with other efficiency techniques.

Abstract

Transformer architectures are the backbone of the modern AI revolution. However, they are based on simply stacking the same blocks in dozens of layers and processing information sequentially from one block to another. In this paper, we propose to challenge this and introduce adaptive computations for LLM-like setups, which allow the final layer to attend to all of the intermediate layers as it deems fit through the attention mechanism, thereby introducing computational attention shortcuts. These shortcuts can thus make the architecture depth- and context-adaptive. We showcase four different datasets spanning acoustic tokens, natural language, and symbolic music, and we achieve superior performance for a GPT-like architecture. We give evidence via attention maps that the models learn complex dependencies across layers that are adaptive in context and depth depending on the input tokens.
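
The mechanism summarized above (MLP feature maps on selected depths, then final-layer attention over the concatenated features under a causal mask) can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. The single attention head, the two-layer ReLU MLP, the concatenation along the sequence axis, and every name and shape below (`mlp_feature_map`, `shortcut_attention`, `T`, `d`, `L`) are assumptions made for illustration only.

```python
import numpy as np

def mlp_feature_map(h, w1, w2):
    """Hypothetical two-layer ReLU MLP applied to each position independently."""
    return np.maximum(h @ w1, 0.0) @ w2

def shortcut_attention(final_h, layer_feats, wq, wk, wv):
    """Single-head attention from final-layer states over stacked layer features.

    final_h:     (T, d) hidden states entering the last layer (queries)
    layer_feats: list of L arrays, each (T, d), one per selected depth
    Returns the attended output (T, d) and the attention weights (T, L*T).
    """
    T, d = final_h.shape
    L = len(layer_feats)
    # Assumption: features from all selected depths are stacked along the sequence axis.
    memory = np.concatenate(layer_feats, axis=0)           # (L*T, d)
    q, k, v = final_h @ wq, memory @ wk, memory @ wv
    scores = q @ k.T / np.sqrt(d)                          # (T, L*T)
    # Causal mask: query position t may only see key positions s <= t, at every depth.
    causal = np.tril(np.ones((T, T), dtype=bool))
    mask = np.tile(causal, (1, L))                         # (T, L*T)
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
T, d, L = 5, 8, 3                                          # toy sizes, chosen arbitrarily
hidden = [rng.normal(size=(T, d)) for _ in range(L)]       # stand-ins for intermediate layers
maps = [(0.1 * rng.normal(size=(d, 2 * d)), 0.1 * rng.normal(size=(2 * d, d)))
        for _ in range(L)]
feats = [mlp_feature_map(h, w1, w2) for h, (w1, w2) in zip(hidden, maps)]
final_h = rng.normal(size=(T, d))
wq, wk, wv = (0.1 * rng.normal(size=(d, d)) for _ in range(3))

out, weights = shortcut_attention(final_h, feats, wq, wk, wv)
# Attention mass each query position places on each depth, shape (T, L).
depth_mass = weights.reshape(T, L, T).sum(axis=-1)
```

In this sketch, each row of `depth_mass` sums to one and shows how a given token splits its attention across the selected depths. This is the kind of quantity an attention map over layers would visualize.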

Paper Structure (9 sections, 2 equations, 3 figures, 3 tables)

Figures (3)

  • Figure 1: (Left) Typical Transformer LLM. (Middle) Our proposed Adaptive LLM, which attends to intermediate-layer embeddings, allowing us to learn adaptive shortcuts in both context and depth. When predicting the last token, the model's final layer can now attend to embeddings (and derived features) at different depths and contexts as it deems fit. The dark curved connectors attend to the second layer, and the dotted connectors attend to the first layer while decoding from the fourth layer. E.g., for the last token in the final layer, it can develop a connection (in green) to learn a shortcut directly from the second token of the 2nd layer via attention. (Right) Our network is dense, with more connections allowed than the vanilla Transformer.
  • Figure 2: Results on four different datasets covering symbolic music, speech tokens, and natural language.
  • Figure 3: Learned attention maps show how we can adaptively attend to any intermediate layer depending on input.