Lambda-Skip Connections: the architectural component that prevents Rank Collapse

Federico Arangath Joseph, Jerome Sieber, Melanie N. Zeilinger, Carmen Amo Alonso

TL;DR

This paper studies rank collapse in deep sequence models and extends the transformer-centric theory to State Space Models within a unifying framework. It introduces lambda-skip connections with a tunable strength and proves a sufficient condition under which rank collapse is prevented across Transformers and SSM variants, leveraging LayerNorm to decouple from input. The authors show necessity via ablations and provide analytical counterexamples illustrating rank collapse without skip connections or under suboptimal lambda choices, complemented by experiments on Mamba/Mamba-2 architectures that also underscore the role of gating and LayerNorm. The work offers theoretical guarantees and practical guidance for robust, expressive sequence models, highlighting how architectural choices shape stability and representational capacity.

Abstract

Rank collapse, a phenomenon where embedding vectors in sequence models rapidly converge to a uniform token or equilibrium state, has recently gained attention in the deep learning literature. This phenomenon leads to reduced expressivity and potential training instabilities due to vanishing gradients. Empirical evidence suggests that architectural components like skip connections, LayerNorm, and MultiLayer Perceptrons (MLPs) play critical roles in mitigating rank collapse. While this issue is well-documented for transformers, alternative sequence models, such as State Space Models (SSMs), which have recently gained prominence, have not been thoroughly examined for similar vulnerabilities. This paper extends the theory of rank collapse from transformers to SSMs using a unifying framework that captures both architectures. We study how a parametrized version of the classic skip connection component, which we call lambda-skip connections, provides guarantees for rank collapse prevention. Through analytical results, we present a sufficient condition to guarantee prevention of rank collapse across all the aforementioned architectures. We also study the necessity of this condition via ablation studies and analytical examples. To our knowledge, this is the first study that provides a general guarantee to prevent rank collapse, and that investigates rank collapse in the context of SSMs, offering valuable understanding for both theoreticians and practitioners. Finally, we validate our findings with experiments demonstrating the crucial role of architectural components such as skip connections and gating mechanisms in preventing rank collapse.

Paper Structure

This paper contains 43 sections, 21 theorems, 51 equations, 6 figures, 1 table.

Key Result

Theorem 4.1

Let the input sequence $Y^{(0)}$ be such that $\mu(Y^{(0)})^2 \geq b$. If the skip connection strength $\lambda$ is chosen to satisfy then we can lower bound $\mu(Y^{(K)})$ by $\mu(Y^{(K)})^2 \geq a^K\mu(Y^{(0)})^2$ for all $K\in\mathbb N$.
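The geometric form of this bound can be illustrated on a toy model. The sketch below is not the paper's architecture. It assumes the rank collapse measure is the Frobenius distance of $Y$ from its token mean, $\mu(Y) = \lVert Y - \tfrac{1}{N}\mathbf{1}\mathbf{1}^\top Y \rVert_F$ (a common definition in this literature; Definition 3.1 is not reproduced here), and it uses a hypothetical layer $Y \mapsto AY + \lambda Y$ with uniform token mixing $A = \tfrac{1}{N}\mathbf{1}\mathbf{1}^\top$ and no LayerNorm. The names `rank_collapse_measure`, `toy_layer`, and `run` are illustrative only. In this toy setting the measure scales exactly as $\mu(Y^{(K)}) = |\lambda|^K \mu(Y^{(0)})$, so it collapses after one layer when $\lambda = 0$ and is preserved when $\lambda = 1$.

```python
import numpy as np


def rank_collapse_measure(Y):
    """Assumed measure: ||Y - (1/N) 1 1^T Y||_F, the distance of Y from its token mean."""
    return np.linalg.norm(Y - Y.mean(axis=0, keepdims=True))


def toy_layer(Y, lam):
    """Hypothetical layer: uniform token mixing plus a lambda-weighted skip connection."""
    n = Y.shape[0]
    A = np.full((n, n), 1.0 / n)  # row-stochastic uniform mixing matrix
    return A @ Y + lam * Y


def run(Y0, lam, depth):
    """Apply the toy layer `depth` times and record the measure after each layer."""
    Y = Y0
    history = [rank_collapse_measure(Y)]
    for _ in range(depth):
        Y = toy_layer(Y, lam)
        history.append(rank_collapse_measure(Y))
    return history


rng = np.random.default_rng(0)
Y0 = rng.standard_normal((8, 4))  # 8 tokens, embedding dimension 4

no_skip = run(Y0, lam=0.0, depth=5)    # collapses to zero after one layer
half_skip = run(Y0, lam=0.5, depth=5)  # decays by a factor of 0.5 per layer
full_skip = run(Y0, lam=1.0, depth=5)  # measure is preserved across layers
```

In this toy model the squared measure satisfies $\mu(Y^{(K)})^2 = \lambda^{2K}\mu(Y^{(0)})^2$, which has the same shape as the bound in Theorem 4.1 with $a = \lambda^2$. The theorem itself covers the normalized Transformer and SSM layers studied in the paper, which this sketch does not model.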

Figures (6)

  • Figure 1: (Normalized) Rank Collapse measure plotted for different values of the skip connection's strength $\lambda$. Again, shaded areas represent one standard deviation from the mean calculated over the 32 examples.
  • Figure 1: Model performance in terms of test accuracy on the Image task of the LRA benchmark and the MQAR task $\{ L=512, \textup{KV-pairs} = 64 \}$.
  • Figure 2: (Normalized) Rank Collapse measure at the last layer plotted for different values of $\lambda$ for both ALBERT and Mamba. Shaded areas represent one standard deviation from the mean calculated over the 32 examples.
  • Figure 3: (Normalized) Rank Collapse measure of the Mamba-2 model plotted as a function of layer depth. The shaded areas represent one standard deviation from the mean calculated over the 32 examples. Gating=True/False indicates a Mamba-2 architecture with/without gating, whereas LN=True/False indicates the architecture with/without LayerNorm.
  • Figure 4: (Normalized) Rank Collapse measure as a function of layer depth for the S4D model. The shaded areas represent one standard deviation from the mean calculated over the 32 examples. Skip=True/False indicates an S4D architecture with/without skip connection, whereas LayerNorm=True/False indicates an S4D architecture with/without LayerNorm.
  • ...and 1 more figure

Theorems & Definitions (42)

  • Definition 3.1
  • Definition 4.1
  • Theorem 4.1: Lower Bound on Rank Collapse
  • proof
  • Remark 4.1
  • Theorem 4.2: Corollary 1 of Wu et al. (2024) (Informal)
  • Theorem 4.3: Rank Collapse for selective SSMs without skip connection
  • proof
  • Proposition 4.3.1
  • proof
  • ...and 32 more