Residual Connections and the Causal Shift: Uncovering a Structural Misalignment in Transformers

Jonathan Lys, Vincent Gripon, Bastien Pasdeloup, Lukas Mauch, Fabien Cardinaux, Ghouthi Boukli Hacene

TL;DR

This work addresses a fundamental misalignment in autoregressive Transformers: causal masking and residual connections create an input-output leakage where representations are anchored to the current token $t_i$ while predictions target the next token $t_{i+1}$. It empirically localizes this shift by tracking token alignment across depth using tied embeddings and the logit lens, revealing a deep transition to output-aligned representations (around layer $17$ in Gemma-2-2B, with similar patterns in other models). Building on this, the authors propose residual attenuation strategies, including a fixed-layer cut and a learnable gating mechanism (mixture-of-depth), implemented as $x_{l+1}=\alpha x_l+F_l(\mathrm{LN}(x_l))$ with $\alpha>0$. Across multiple benchmarks, the learned gating method consistently improves or matches baseline performance while reducing input-output misalignment, offering a robust, low-cost architectural improvement for autoregressive Transformers. Together, these results provide both a deeper understanding of depth-wise representation alignment and practical tools to enhance the efficiency and accuracy of large language models.
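
The attenuated update $x_{l+1}=\alpha x_l+F_l(\mathrm{LN}(x_l))$ can be illustrated with a minimal sketch. It assumes a plain-Python vector representation, a layer norm without learned scale or shift, and a hypothetical `toy_block` standing in for $F_l$; none of these names come from the paper, and the sketch is not the authors' implementation.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned parameters)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def attenuated_residual(x, block, alpha):
    """Return alpha * x + block(LN(x)) for a single layer."""
    update = block(layer_norm(x))
    return [alpha * xi + ui for xi, ui in zip(x, update)]


def toy_block(h):
    # Hypothetical stand-in for F_l (attention + MLP in a real Transformer).
    return [0.5 * v for v in h]


x = [1.0, 2.0, 3.0, 4.0]

# alpha = 1 recovers the standard pre-norm residual update.
standard = attenuated_residual(x, toy_block, alpha=1.0)

# 0 < alpha < 1 shrinks the contribution of the current-token stream.
attenuated = attenuated_residual(x, toy_block, alpha=0.25)
```

With the same block output, the two results differ only by $(1-\alpha)\,x_l$, which is the part of the current-token representation that the attenuation removes.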

Abstract

Large Language Models (LLMs) are trained with next-token prediction, implemented in autoregressive Transformers via causal masking for parallelism. This creates a subtle misalignment: residual connections tie activations to the current token, while supervision targets the next token, potentially propagating mismatched information if the current token is not the most informative for prediction. In this work, we empirically localize this input-output alignment shift in pretrained LLMs, using decoding trajectories over tied embedding spaces and similarity-based metrics. Our experiments reveal that the hidden token representations switch from input alignment to output alignment deep within the network. Motivated by this observation, we propose a lightweight residual-path mitigation based on residual attenuation, implemented either as a fixed-layer intervention or as a learnable gating mechanism. Experiments on multiple benchmarks show that these strategies alleviate the representation misalignment and yield improvements, providing an efficient and general architectural enhancement for autoregressive Transformers.
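
The alignment measurements mentioned here (top-k match rate, cosine similarity, and a normalized projection with 0 at the input token and 1 at the output token) can be sketched as follows. This is an illustrative reading of the figure captions, not the authors' code: the projection formula is one natural definition consistent with the stated endpoints, and the one-hot toy embeddings and all function names are assumptions.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    """Cosine similarity between two vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def normalized_projection(h, e_in, e_out):
    """Position of h along the e_in -> e_out axis: 0 at e_in, 1 at e_out."""
    axis = [o - i for i, o in zip(e_in, e_out)]
    offset = [v - i for v, i in zip(h, e_in)]
    return dot(offset, axis) / dot(axis, axis)


def top_k_tokens(h, embeddings, k):
    """Decode h with tied embeddings: rank tokens by inner product with h."""
    scores = [dot(h, e) for e in embeddings]
    order = sorted(range(len(embeddings)), key=lambda t: scores[t], reverse=True)
    return order[:k]


def match_rate(hidden_states, targets, embeddings, k):
    """Fraction of positions whose target token is among the top-k decoded tokens."""
    hits = sum(t in top_k_tokens(h, embeddings, k)
               for h, t in zip(hidden_states, targets))
    return hits / len(targets)


# Toy vocabulary of four one-hot embeddings (assumption for illustration).
E = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
inputs = [0, 1]    # tokens t_i
outputs = [1, 2]   # shifted tokens t_{i+1}

# Hypothetical hidden states: early layers sit near the input embedding,
# late layers near the next-token embedding.
early = [[0.9, 0.1, 0.0, 0.0], [0.0, 0.9, 0.1, 0.0]]
late = [[0.1, 0.9, 0.0, 0.0], [0.0, 0.1, 0.9, 0.0]]

early_input_rate = match_rate(early, inputs, E, k=1)
late_output_rate = match_rate(late, outputs, E, k=1)
early_position = normalized_projection(early[0], E[0], E[1])
late_position = normalized_projection(late[0], E[0], E[1])
```

In this toy setup the early states decode to the input tokens and project near 0, while the late states decode to the shifted tokens and project near 1, which mirrors the kind of transition the paper reports across depth.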

Paper Structure (6 sections, 1 equation, 5 figures, 2 tables)

Figures (5)

  • Figure 1: Schematic illustration of the token misalignment between input and output in Transformer architectures (here with a single layer). Token $t_i$ in the input is directly connected to token $t_{i+1}$ in the output through the residual connections.
  • Figure 2: Average top-5 match rate between the decoded hidden states and either the input sequence (blue) or the shifted input sequence (red) as a function of layer depth in the Gemma-2-2B model. Results are averaged over 1,000 sequences from the Wikitext dataset. The shift from input to output indexing occurs late in the architecture.
  • Figure 3: Continuous similarity measures between hidden states and their corresponding input and shifted input tokens, reported for Gemma-2-2B [team2024gemma], Llama-3.2-3B [grattafiori2024llama], and Mistral-7B-v0.3 [jiang2023mistral7b]. Metrics include cosine similarity to input and output token embeddings, and normalized projection along the input–output axis (0: input token, 1: output token). Results confirm the latent transition from input-based to output-based representations across different architectures.
  • Figure 4: Impact of cutting the residual path at different layers on the validation loss of GPT2-0.1B, trained on 10B tokens from Fineweb [penedo2024the]. The baseline loss is shown as a dotted line.
  • Figure 5: Evolution, during training, of the learned probability distribution over the layer at which the residual path is cut. At the beginning, all 12 layers have the same cut probability; the 11th layer then attracts all of the probability mass.
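
The behavior described in the Figure 5 caption can be pictured with a small sketch of a categorical distribution over cut layers. It assumes the distribution is a softmax over one learnable logit per layer, which the captions do not state; the logit value used for the "trained" case is hypothetical, and how the paper applies the resulting probabilities to the residual path is not modeled here.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


num_layers = 12

# Equal logits give the uniform starting point: each layer has probability 1/12.
initial_logits = [0.0] * num_layers
initial_probs = softmax(initial_logits)

# Hypothetical state after training: the logit of the 11th layer (index 10)
# has grown, so nearly all of the probability mass sits on that layer.
trained_logits = list(initial_logits)
trained_logits[10] = 8.0
trained_probs = softmax(trained_logits)

most_likely_layer = trained_probs.index(max(trained_probs)) + 1  # 1-based index
```

Under this assumption, a single growing logit is enough to move the distribution from uniform to sharply peaked, which matches the qualitative evolution the caption describes.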