Adaptive Large Language Models By Layerwise Attention Shortcuts
Prateek Verma, Mert Pilanci
TL;DR
The work addresses the fixed, inflexible depth of decoder-only transformers by enabling adaptive computation through attention shortcuts that let the final layer attend to intermediate-layer representations. It introduces learned MLP-based feature maps at selected depths and a cross-attention mechanism in the last layer over the concatenated intermediate features under a causal mask. Across speech, text, and music datasets, the approach yields lower next-token NLL, and attention maps show that the model adaptively allocates computation across depth depending on input complexity. This depth- and context-adaptive strategy promises more efficient pre-training and scalable performance, with ablations and scaling analyses guiding future integration with other efficiency techniques.
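To make the mechanism above concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes a two-layer ReLU MLP as each per-depth feature map, single-head attention whose queries come from the final layer's input and whose keys and values come from the concatenated intermediate features, and a causal mask under which position t may see positions s <= t at every selected depth. The function and parameter names (`layerwise_shortcut_attention`, `mlp_params`, `wq`, `wk`, `wv`) are hypothetical. It also returns how much attention each query places on each depth, which is the kind of quantity an attention-map analysis of depth adaptivity would inspect.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax; masked entries are -inf and become 0
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mlp(x, w1, b1, w2, b2):
    # Hypothetical per-depth feature map: two-layer MLP with ReLU, position-wise
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2

def layerwise_shortcut_attention(final_h, inter_hs, mlp_params, wq, wk, wv):
    """Final-layer cross-attention over features from K selected depths.

    final_h:    (T, d) hidden states entering the final layer
    inter_hs:   list of K arrays, each (T, d), from selected intermediate depths
    mlp_params: list of K tuples (w1, b1, w2, b2), one feature map per depth
    Returns the attention output (T, d) and per-depth attention mass (T, K).
    """
    T, d = final_h.shape
    K = len(inter_hs)
    # Map each depth through its own MLP, then concatenate depth-major: (K*T, d)
    feats = np.concatenate([mlp(h, *p) for h, p in zip(inter_hs, mlp_params)], axis=0)
    q = final_h @ wq                      # (T, d) queries from the final layer
    k = feats @ wk                        # (K*T, d) keys from all selected depths
    v = feats @ wv                        # (K*T, d) values from all selected depths
    scores = q @ k.T / np.sqrt(d)         # (T, K*T)
    # Causal mask: query position t sees source position s <= t at any depth
    src_pos = np.tile(np.arange(T), K)    # source position of each memory slot
    mask = src_pos[None, :] <= np.arange(T)[:, None]
    scores = np.where(mask, scores, -np.inf)
    attn = softmax(scores, axis=-1)       # (T, K*T)
    out = attn @ v
    # Columns are ordered (depth, position), so reshape and sum over positions
    per_depth = attn.reshape(T, K, T).sum(axis=-1)
    return out, per_depth

# Usage with random weights
rng = np.random.default_rng(0)
T, d, K = 5, 8, 3
final_h = rng.standard_normal((T, d))
inter_hs = [rng.standard_normal((T, d)) for _ in range(K)]
mlp_params = [(0.1 * rng.standard_normal((d, 2 * d)), np.zeros(2 * d),
               0.1 * rng.standard_normal((2 * d, d)), np.zeros(d)) for _ in range(K)]
wq, wk, wv = (0.1 * rng.standard_normal((d, d)) for _ in range(3))
out, per_depth = layerwise_shortcut_attention(final_h, inter_hs, mlp_params, wq, wk, wv)
print(out.shape, per_depth.shape)  # (5, 8) (5, 3)
```

Each row of `per_depth` sums to 1, so a row that concentrates on one depth indicates that the token draws mostly on that layer's features. Whether the paper combines this output with a residual stream, uses multiple heads, or masks differently is not specified here and is left as an assumption.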
Abstract
Transformer architectures are the backbone of the modern AI revolution. However, they are built by simply stacking the same block dozens of times and passing information sequentially from one block to the next. In this paper, we challenge this design and introduce adaptive computation for LLM-like setups: the final layer can attend, through the attention mechanism, to any of the intermediate layers as it deems fit, thereby introducing computational attention shortcuts. These shortcuts make the architecture adaptive in both depth and context. We evaluate on four datasets spanning acoustic tokens, natural language, and symbolic music, and achieve superior performance (lower next-token NLL) for a GPT-like architecture. We give evidence via attention maps that the models learn complex dependencies across layers that adapt in context and depth depending on the input tokens.
