One Jump Is All You Need: Short-Cutting Transformers for Early Exit Prediction with One Jump to Fit All Exit Levels
Amrit Diggavi Seshadri
TL;DR
The paper tackles the high inference cost of deep transformers by enabling efficient early exit through a single low-rank shortcut. It introduces One-Jump-Fits-All (OJFA), which uses a Signed Sensitive Cosine Similarity criterion to select one jump that serves all exit levels. The shortcuts build on the N-NJTC framework (low-rank factors $A_k$, $B_k$, and BatchNorm). OJFA achieves over a $30\times$ reduction in shortcut parameters while largely preserving performance across GPT2-XL, Phi3-Mini, and Llama2-7B, and it closely matches multi-jump methods at many exit levels. It outperforms identity shortcuts at early stages and demonstrates substantial practical efficiency gains, though some limitations and safety considerations remain for industrial deployment.
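To make the idea of "one jump for all exit levels" concrete, here is a minimal NumPy sketch under stated assumptions; it is not the paper's implementation. The exact N-NJTC parameterization is not given in this section, so the jump form below, $\hat{h}_L = B_k A_k \,\mathrm{BN}(h)$ with inference-mode BatchNorm and no learned affine, is hypothetical, as are the names `LowRankJump`, `mean_cosine`, and `select_one_jump`. The selection step also swaps in plain mean cosine similarity in place of the paper's Signed Sensitive Cosine Similarity criterion: each candidate jump (one trained per exit level) is scored at every exit level, and the candidate with the best average score is kept.

```python
import numpy as np

rng = np.random.default_rng(0)


def batchnorm_eval(h, mean, var, eps=1e-5):
    # Inference-mode batch normalization with fixed statistics
    # (no learned scale/shift, for brevity).
    return (h - mean) / np.sqrt(var + eps)


class LowRankJump:
    """Hypothetical low-rank shortcut: h_final_hat = B @ A @ BN(h)."""

    def __init__(self, d, r, rng):
        self.A = rng.normal(scale=1 / np.sqrt(d), size=(r, d))  # d -> r projection
        self.B = rng.normal(scale=1 / np.sqrt(r), size=(d, r))  # r -> d projection
        self.mean = np.zeros(d)
        self.var = np.ones(d)

    def __call__(self, h):
        # h: (n, d) hidden states at some exit level -> (n, d) estimate of final states.
        z = batchnorm_eval(h, self.mean, self.var)
        return z @ self.A.T @ self.B.T

    def num_params(self):
        # Counts only the two low-rank factors.
        return self.A.size + self.B.size


def mean_cosine(x, y):
    # Row-wise cosine similarity, averaged over samples.
    num = np.sum(x * y, axis=1)
    den = np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1) + 1e-12
    return float(np.mean(num / den))


def select_one_jump(jumps, hidden_by_level, final):
    """Score each candidate jump at every exit level; keep the best on average."""
    scores = []
    for jump in jumps:
        per_level = [mean_cosine(jump(h), final) for h in hidden_by_level]
        scores.append(float(np.mean(per_level)))
    best = int(np.argmax(scores))
    return best, scores


# Toy usage with random data: 4 exit levels, 16 samples, width d=32, rank r=4.
d, r, n_levels, n = 32, 4, 4, 16
hidden_by_level = [rng.normal(size=(n, d)) for _ in range(n_levels)]
final = rng.normal(size=(n, d))
jumps = [LowRankJump(d, r, rng) for _ in range(n_levels)]  # one candidate per level
best, scores = select_one_jump(jumps, hidden_by_level, final)
print("selected jump:", best, "params kept:", jumps[best].num_params())
```

At inference time only `jumps[best]` would be stored and applied from whichever block the model exits at, which is where the parameter saving comes from; in practice the candidates would be trained jumps scored on held-out hidden states rather than random matrices.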
Abstract
To reduce the time and computational cost of large language model inference, there has been interest in parameter-efficient, low-rank early-exit casting of transformer hidden representations to final representations. Such low-rank short-cutting has been shown to outperform identity shortcuts at early model stages while keeping each shortcut jump parameter-efficient. However, current low-rank methods maintain a separate early-exit shortcut jump to the final representation for each intermediate transformer block level during inference. In this work, we propose selecting a single One-Jump-Fits-All (OJFA) low-rank shortcut, which offers over a $30\times$ reduction in shortcut parameter costs during inference. We show that despite this extreme reduction, our OJFA choice largely matches the performance of maintaining multiple shortcut jumps during inference and offers stable precision from all transformer block levels for the GPT2-XL, Phi3-Mini, and Llama2-7B transformer models.
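A quick back-of-the-envelope sketch shows where a reduction of this size can come from. It assumes, hypothetically, that every jump has the same shape (two low-rank factors of rank $r$ over model width $d$, plus a learned BatchNorm scale and shift) and that the multi-jump baseline keeps one jump for each block except the last. The rank `RANK = 64` is an arbitrary illustrative value, and the paper's exact accounting may differ. The block counts and widths are the public configurations of the three models.

```python
from dataclasses import dataclass


@dataclass
class ModelShape:
    name: str
    num_blocks: int  # number of transformer blocks L
    hidden_dim: int  # model width d


def jump_params(d: int, r: int, with_bn: bool = True) -> int:
    # Two low-rank factors (r x d and d x r) plus, optionally,
    # a learned BatchNorm scale and shift over d features (hypothetical form).
    return 2 * d * r + (2 * d if with_bn else 0)


def shortcut_costs(m: ModelShape, r: int):
    per_jump = jump_params(m.hidden_dim, r)
    exit_levels = m.num_blocks - 1  # assumed: one exit after each non-final block
    multi = exit_levels * per_jump  # baseline: a separate jump per exit level
    one = per_jump                  # OJFA: a single shared jump
    return multi, one, multi / one


MODELS = [
    ModelShape("GPT2-XL", 48, 1600),
    ModelShape("Phi3-Mini", 32, 3072),
    ModelShape("Llama2-7B", 32, 4096),
]

RANK = 64  # hypothetical rank, for illustration only

for m in MODELS:
    multi, one, ratio = shortcut_costs(m, RANK)
    print(f"{m.name}: multi-jump={multi:,} one-jump={one:,} reduction={ratio:.0f}x")
```

Under these assumptions the reduction factor is simply the number of exit levels and does not depend on the rank: 31 for the two 32-block models and 47 for the 48-block GPT2-XL, which is in line with the reported "over $30\times$" figure.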
