The Remarkable Robustness of LLMs: Stages of Inference?

Vedang Lad, Jin Hwa Lee, Wes Gurnee, Max Tegmark

TL;DR

The paper investigates how decoder-only LLMs remain robust under layer deletions and adjacent-layer swaps, revealing nonuniform depth-dependent effects. It proposes a universal four-stage inference framework—detokenization, feature engineering, prediction ensembling, and residual sharpening—to interpret depth-wise computations across model families. Through layer interventions and targeted probes (e.g., WiC probing, logit lens, CKA analyses), it shows early and late layers are most sensitive while middle layers display notable resilience, supported by neuron-level analyses of prediction and suppression ensembles. The work provides a cohesive perspective on how redundancy and residual pathways enable self-repair and ensembling, with broad implications for interpretability, auditing, and robust model design across varied transformer architectures.

Abstract

We investigate the robustness of Large Language Models (LLMs) to structural interventions by deleting and swapping adjacent layers during inference. Surprisingly, models retain 72-95% of their original top-1 prediction accuracy without any fine-tuning. We find that performance degradation is not uniform across layers: interventions to the early and final layers cause the most degradation, while the model is remarkably robust to dropping middle layers. This pattern of localized sensitivity motivates our hypothesis of four stages of inference, observed across diverse model families and sizes: (1) detokenization, where local context is integrated to lift raw token embeddings into higher-level representations; (2) feature engineering, where task- and entity-specific features are iteratively refined; (3) prediction ensembling, where hidden states are aggregated into plausible next-token predictions; and (4) residual sharpening, where irrelevant features are suppressed to finalize the output distribution. Synthesizing behavioral and mechanistic evidence, we provide a framework for interpreting depth-dependent computations in LLMs.
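
The two interventions and the two metrics described in the abstract can be illustrated with a minimal sketch. It assumes a hypothetical toy model (random residual blocks with a `tanh` update and a random unembedding) in place of a real LLM, and it assumes the KL divergence is taken from the original model's distribution to the intervened one. All function names here are illustrative and are not taken from the paper's code. Because the weights are random, the sketch shows only the bookkeeping, not the depth-dependent pattern the paper reports.

```python
import math
import random


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(p):
    return max(range(len(p)), key=p.__getitem__)


def make_toy_model(n_layers=8, d=16, vocab=32, seed=0):
    """Hypothetical stand-in for an LLM: random residual blocks plus an unembedding."""
    rng = random.Random(seed)
    layers = [[[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)]
               for _ in range(d)] for _ in range(n_layers)]
    unembed = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(vocab)]
    return layers, unembed


def forward(layers, unembed, x, order=None):
    """Apply residual blocks in the given order; return next-token probabilities."""
    if order is None:
        order = range(len(layers))
    h = list(x)
    for i in order:
        update = [math.tanh(v) for v in matvec(layers[i], h)]
        h = [a + b for a, b in zip(h, update)]  # residual connection
    return softmax(matvec(unembed, h))


def drop_layer(n_layers, l):
    """Layer order with block l deleted."""
    return [i for i in range(n_layers) if i != l]


def swap_adjacent(n_layers, l):
    """Layer order with blocks l and l + 1 exchanged."""
    order = list(range(n_layers))
    order[l], order[l + 1] = order[l + 1], order[l]
    return order


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions (assumed direction: original || intervened)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def evaluate(layers, unembed, inputs, order):
    """Mean KL and top-1 agreement between the original and the intervened model."""
    base = [forward(layers, unembed, x) for x in inputs]
    alt = [forward(layers, unembed, x, order) for x in inputs]
    mean_kl = sum(kl_divergence(p, q) for p, q in zip(base, alt)) / len(inputs)
    agree = sum(argmax(p) == argmax(q) for p, q in zip(base, alt)) / len(inputs)
    return mean_kl, agree


layers, unembed = make_toy_model()
rng = random.Random(1)
inputs = [[rng.gauss(0.0, 1.0) for _ in range(16)] for _ in range(20)]
n = len(layers)
for l in range(n - 1):
    drop_kl, drop_acc = evaluate(layers, unembed, inputs, drop_layer(n, l))
    swap_kl, swap_acc = evaluate(layers, unembed, inputs, swap_adjacent(n, l))
    print(f"layer {l}: drop KL={drop_kl:.3f} top1={drop_acc:.2f} | "
          f"swap KL={swap_kl:.3f} top1={swap_acc:.2f}")
```

Expressing both interventions as a reordering of layer indices keeps the forward pass unchanged, so the same evaluation loop covers deletion and swapping. With a real model, the equivalent step would be to skip or reorder transformer blocks at inference time without any fine-tuning.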

Paper Structure (50 sections, 4 equations, 21 figures, 3 tables)


Figures (21)

  • Figure 1: Statistical signatures of universal stages of inference across three model families. (Blue) KL divergence between the normal model and the model with layer $\ell$ zero-ablated. (Purple) Total attention paid to the previous five tokens in a sequence. (Green) The number of “prediction” neurons. (Red) The number of suppression neurons (geva2020transformer; voita2023neurons; gurnee2024universal).
  • Figure 2: (a) Effect of layer swap (top) and layer drop (bottom) interventions on model behavior. (left) KL divergence between the intervened and original models. (right) Consistency of top-1 predictions. (b)(c) Representational similarity across layers measured using CKA, showing block-like structure in GPT-2 XL (b) and Pythia 2.8B (c). Similar trends are observed across other model families and sizes (see Appendix \ref{app:cka}).
  • Figure 3: (a) The average attention weight (across heads within a layer and across query tokens) placed on the preceding 1, 2, 4, 8, and 16 tokens for each layer. (b) Attention from the source token to the final token in various inputs. An identified sub-joiner attention head (bottom), found in the early layers of language models, is responsible for attending to multi-token words (e.g., shenanigans, refurbishments, parfaitement, circumnavigate), compared to a baseline set of random non-multi-token words (top).
  • Figure 3: List of models and dataset used in the experiments.
  • Figure 4: (a) Layer-wise probe accuracy on contextual lexical meaning (WiC task). The peak in intermediate layers suggests where semantic features are linearly encoded. (b) Using the logit lens technique (nostalgebraist2020interpreting), we calculate the probability distribution of the next token at the end of every layer and then take its entropy, which provides a measure of the model's confidence in the next prediction. Despite high probe accuracy, the high-entropy residual stream suggests that semantic features exist mid-model but are not yet used for prediction. For all models see Appendix \ref{app:wic2_fig2} and \ref{fig:4sharpall}.
  • ...and 16 more figures
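
Two of the measurements named in the captions above, CKA (Figure 2) and logit-lens entropy (Figure 4), can be sketched in a few lines. This is a minimal illustration under stated assumptions: it uses the linear-kernel form of CKA (the captions do not say which kernel the paper uses), it omits the final layer norm that a logit lens normally applies before unembedding, and the activations and unembedding matrix are random placeholders rather than values from a real model. The function names are hypothetical.

```python
import math
import random


def center_columns(X):
    """Subtract each feature's mean across examples (X is n_examples x n_features)."""
    n = len(X)
    means = [sum(col) / n for col in zip(*X)]
    return [[v - m for v, m in zip(row, means)] for row in X]


def cross_frobenius_sq(X, Y):
    """Squared Frobenius norm of X^T Y for X (n x d1) and Y (n x d2)."""
    total = 0.0
    for x_col in zip(*X):
        for y_col in zip(*Y):
            dot = sum(a * b for a, b in zip(x_col, y_col))
            total += dot * dot
    return total


def linear_cka(X, Y):
    """Linear CKA between two sets of activations for the same n examples."""
    Xc, Yc = center_columns(X), center_columns(Y)
    denom = math.sqrt(cross_frobenius_sq(Xc, Xc) * cross_frobenius_sq(Yc, Yc))
    return cross_frobenius_sq(Xc, Yc) / denom


def logit_lens_entropy(h, unembed):
    """Entropy (in nats) of softmax(unembed @ h); final layer norm omitted for brevity."""
    logits = [sum(w * v for w, v in zip(row, h)) for row in unembed]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0.0)


# Placeholder data standing in for per-layer activations and an unembedding matrix.
rng = random.Random(0)
X = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(30)]
X_scaled = [[2.0 * v for v in row] for row in X]
Z = [[rng.gauss(0.0, 1.0) for _ in range(7)] for _ in range(30)]
unembed = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(12)]

print("CKA(X, 2X):", linear_cka(X, X_scaled))
print("CKA(X, Z): ", linear_cka(X, Z))
print("entropy at one hidden state:", logit_lens_entropy(X[0], unembed))
```

Linear CKA is invariant to isotropic rescaling of either representation, which is one reason it suits comparisons between layers whose activation norms differ. For the entropy, values near the logarithm of the vocabulary size indicate a nearly uniform intermediate prediction, and lower values indicate a more confident one.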