Lambda-Skip Connections: the architectural component that prevents Rank Collapse
Federico Arangath Joseph, Jerome Sieber, Melanie N. Zeilinger, Carmen Amo Alonso
TL;DR
This paper studies rank collapse in deep sequence models and extends the transformer-centric theory to State Space Models (SSMs) within a unifying framework. It introduces lambda-skip connections, a parametrized skip connection with tunable strength lambda, and proves a sufficient condition under which rank collapse is prevented across Transformers and SSM variants, leveraging LayerNorm to decouple the condition from the input. The authors examine the necessity of this condition via ablations and provide analytical counterexamples in which rank collapse occurs without skip connections or under poorly chosen values of lambda. Experiments on Mamba and Mamba-2 architectures complement the theory and underscore the role of gating and LayerNorm. The work offers theoretical guarantees and practical guidance for robust, expressive sequence models, highlighting how architectural choices shape stability and representational capacity.
Abstract
Rank collapse, a phenomenon where embedding vectors in sequence models rapidly converge to a uniform token or equilibrium state, has recently gained attention in the deep learning literature. This phenomenon leads to reduced expressivity and potential training instabilities due to vanishing gradients. Empirical evidence suggests that architectural components like skip connections, LayerNorm, and Multilayer Perceptrons (MLPs) play critical roles in mitigating rank collapse. While this issue is well-documented for transformers, alternative sequence models, such as State Space Models (SSMs), which have recently gained prominence, have not been thoroughly examined for similar vulnerabilities. This paper extends the theory of rank collapse from transformers to SSMs using a unifying framework that captures both architectures. We study how a parametrized version of the classic skip connection component, which we call lambda-skip connections, provides guarantees for rank collapse prevention. Through analytical results, we present a sufficient condition to guarantee prevention of rank collapse across all the aforementioned architectures. We also study the necessity of this condition via ablation studies and analytical examples. To our knowledge, this is the first study that provides a general guarantee to prevent rank collapse, and that investigates rank collapse in the context of SSMs, offering valuable understanding for both theoreticians and practitioners. Finally, we validate our findings with experiments demonstrating the crucial role of architectural components such as skip connections and gating mechanisms in preventing rank collapse.
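To make the notion of rank collapse and the effect of the skip strength concrete, the following is a minimal numerical sketch and not the construction used in the paper. It assumes a toy layer of the form X <- lambda * X + A X, where A is a fixed uniform token-mixing matrix standing in for attention or an SSM kernel, and it omits LayerNorm, value projections, MLPs, and gating. The names `residual_ratio` and `run_layers` and the chosen values of lambda are hypothetical and purely illustrative. Collapse is measured as the relative distance of the tokens from their common mean.

```python
import numpy as np

def residual_ratio(X):
    """Relative distance of the tokens (rows of X) from their mean token."""
    mean_token = X.mean(axis=0, keepdims=True)
    return np.linalg.norm(X - mean_token) / np.linalg.norm(X)

def run_layers(X, lam, num_layers):
    """Apply the toy layer X <- lam * X + A @ X repeatedly.

    A is uniform mixing (every token attends equally to all tokens),
    an assumption made only to keep the example analytically simple.
    """
    n = X.shape[0]
    A = np.full((n, n), 1.0 / n)
    for _ in range(num_layers):
        X = lam * X + A @ X
        # Global rescaling avoids overflow and leaves the ratio unchanged.
        X = X / np.linalg.norm(X)
    return X

rng = np.random.default_rng(0)
X0 = rng.standard_normal((8, 4))  # 8 tokens, embedding dimension 4

ratios = {
    lam: residual_ratio(run_layers(X0, lam, num_layers=20))
    for lam in (0.0, 1.0, -2.0)
}
for lam, ratio in ratios.items():
    print(f"lambda={lam:+.1f}  residual ratio after 20 layers: {ratio:.3e}")
```

In this toy setting the token-to-mean residual is scaled by lambda per layer while the mean component is scaled by 1 + lambda, so the ratio between them changes by a factor of |lambda / (1 + lambda)| per layer. Without a skip connection (lambda = 0) the tokens collapse after a single layer, lambda = 1 only slows the collapse, and lambda = -2 preserves the residual. This illustrates why the choice of lambda can matter. It does not reproduce the sufficient condition derived in the paper.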
