Residual Connections and the Causal Shift: Uncovering a Structural Misalignment in Transformers
Jonathan Lys, Vincent Gripon, Bastien Pasdeloup, Lukas Mauch, Fabien Cardinaux, Ghouthi Boukli Hacene
TL;DR
This work addresses a fundamental misalignment in autoregressive Transformers: causal masking and residual connections create an input-output leakage in which each position's representation is anchored to the current token $t_i$ while its prediction targets the next token $t_{i+1}$. The authors empirically localize this shift by tracking token alignment across depth with tied embeddings and the logit lens, revealing that the transition to output-aligned representations happens deep in the network (around layer $17$ in Gemma-2-2B, with similar patterns in other models). Building on this, they propose residual attenuation strategies, including a fixed-layer cut and a learnable gating mechanism (mixture-of-depth), implemented as $x_{l+1}=\alpha x_l+F_l(\mathrm{LN}(x_l))$, where the residual scale $\alpha$ is reduced below $1$ to attenuate the skip path. Across multiple benchmarks, the learned gating method consistently matches or improves baseline performance while reducing input-output misalignment, offering a robust, low-cost architectural improvement for autoregressive Transformers. Together, these results provide both a deeper understanding of depth-wise representation alignment and practical tools for improving the efficiency and accuracy of large language models.
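The attenuated residual update $x_{l+1}=\alpha x_l+F_l(\mathrm{LN}(x_l))$ can be illustrated with a minimal pure-Python sketch. It assumes a pre-LN block, a LayerNorm without learned gain or bias, and a generic sublayer `F_l` passed in as a function. The helpers `fixed_cut_schedule` (one attenuated layer, $\alpha=1$ elsewhere) and `sigmoid_gate` (a per-token $\alpha\in(0,1)$) are hypothetical illustrations of the two strategies, not the authors' implementation or their exact mixture-of-depth formulation.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def layer_norm(x: Sequence[float], eps: float = 1e-5) -> Vector:
    """LayerNorm without learned gain/bias (a simplification assumed here)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def attenuated_block(x: Vector, sublayer: Callable[[Vector], Vector], alpha: float) -> Vector:
    """One pre-LN residual update: x_{l+1} = alpha * x_l + F_l(LN(x_l))."""
    update = sublayer(layer_norm(x))
    return [alpha * xi + ui for xi, ui in zip(x, update)]


def forward(
    x: Vector,
    sublayers: Sequence[Callable[[Vector], Vector]],
    alphas: Sequence[float],
) -> Vector:
    """Run the stack with one residual scale per layer (alpha = 1 is a standard residual)."""
    for sublayer, alpha in zip(sublayers, alphas):
        x = attenuated_block(x, sublayer, alpha)
    return x


def fixed_cut_schedule(num_layers: int, cut_layer: int, alpha_cut: float = 0.0) -> List[float]:
    """Hypothetical fixed-layer intervention: alpha = 1 everywhere except at
    `cut_layer`, where the skip path is scaled by `alpha_cut` (0.0 drops it)."""
    return [alpha_cut if layer == cut_layer else 1.0 for layer in range(num_layers)]


def sigmoid_gate(x: Sequence[float], w: Sequence[float], b: float) -> float:
    """Hypothetical learnable gate producing a per-token alpha in (0, 1)."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))


def zero_sublayer(v: Vector) -> Vector:
    """Toy F_l that outputs zeros, so only the residual scaling is visible."""
    return [0.0] * len(v)


print(forward([1.0, 2.0], [zero_sublayer] * 3, fixed_cut_schedule(3, cut_layer=1, alpha_cut=0.5)))
# prints [0.5, 1.0]
```

With a zero sublayer, the output is simply the product of the per-layer scales times the input, which makes the effect of attenuating a single layer easy to see. In a real model, $\alpha<1$ at a given depth reduces how much of the current-token signal $t_i$ is carried forward unchanged.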
Abstract
Large Language Models (LLMs) are trained with next-token prediction, which autoregressive Transformers implement with causal masking so that all positions can be trained in parallel. This creates a subtle misalignment: residual connections tie each position's activations to the current token, while supervision targets the next token, potentially propagating mismatched information when the current token is not the most informative one for prediction. In this work, we empirically localize this input-output alignment shift in pretrained LLMs using decoding trajectories over tied embedding spaces and similarity-based metrics. Our experiments reveal that hidden token representations switch from input alignment to output alignment deep within the network. Motivated by this observation, we propose a lightweight mitigation on the residual path based on residual attenuation, implemented either as a fixed-layer intervention or as a learnable gating mechanism. Experiments on multiple benchmarks show that these strategies alleviate the representation misalignment and yield performance improvements, providing an efficient and general architectural enhancement for autoregressive Transformers.
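One plausible way to track input versus output alignment across depth is sketched below in pure Python. It assumes per-layer hidden states for a single position $i$ are already available, that the input and output embeddings are tied (one matrix `E`), and that cosine similarity to $E[t_i]$ and $E[t_{i+1}]$ serves as the similarity metric. The model's final normalization is omitted in the logit-lens helper, and the `switch_layer` criterion (the first layer from which output alignment wins through to the last layer) is a hypothetical definition, not necessarily the metric used in the paper. The toy embeddings and hidden states are made up purely for illustration.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; returns 0.0 if either vector is zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a > 0 and norm_b > 0 else 0.0


def logit_lens_top1(hidden: Vector, embedding: Sequence[Vector]) -> int:
    """Logit lens with tied embeddings: score each vocabulary row E[v] by <E[v], h>
    and return the top token id (the model's final norm is omitted here)."""
    logits = [sum(e * h for e, h in zip(row, hidden)) for row in embedding]
    return max(range(len(logits)), key=logits.__getitem__)


def alignment_trajectory(
    hiddens_per_layer: Sequence[Vector],
    embedding: Sequence[Vector],
    current_id: int,
    next_id: int,
) -> List[Tuple[float, float]]:
    """Per layer at one position i: (similarity to E[t_i], similarity to E[t_{i+1}])."""
    return [
        (cosine(h, embedding[current_id]), cosine(h, embedding[next_id]))
        for h in hiddens_per_layer
    ]


def switch_layer(trajectory: Sequence[Tuple[float, float]]) -> int:
    """First layer from which the state stays closer to the next token than to the
    current one through the last layer; -1 if the final layer is still input-aligned."""
    switch = -1
    for layer in range(len(trajectory) - 1, -1, -1):
        sim_in, sim_out = trajectory[layer]
        if sim_out > sim_in:
            switch = layer
        else:
            break
    return switch


# Toy example: 3-token vocabulary with one-hot embeddings, t_i = 0, t_{i+1} = 1.
E = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
hiddens = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.6, 0.4, 0.0], [0.3, 0.7, 0.0], [0.1, 0.9, 0.0]]
traj = alignment_trajectory(hiddens, E, current_id=0, next_id=1)
print(switch_layer(traj))  # prints 3
print([logit_lens_top1(h, E) for h in hiddens])  # prints [0, 0, 0, 1, 1]
```

In this toy trajectory, early layers decode to the current token and later layers to the next token; on a real model, the same bookkeeping over all positions would show where in depth the switch typically occurs.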
