StableSSM: Alleviating the Curse of Memory in State-space Models through Stable Reparameterization

Shida Wang, Qianxiao Li

TL;DR

StableSSM studies long-term memory in sequence modeling. It proves that state-space models without reparameterization inherit a curse of memory similar to that of RNNs: they can stably approximate only targets with exponentially decaying memory. It introduces a class of stable reparameterizations that lifts this limitation and improves optimization stability, including a principled 'best' parameterization that balances gradient scales. The approach is validated on synthetic tasks, language modeling on WikiText-103, image classification, and Long Range Arena benchmarks. The result is a theoretical and practical framework for designing stable SSMs with long memory: stable reparameterization enables stable learning of targets with decaying memory and also improves training stability for large-scale sequence models.

Abstract

In this paper, we investigate the long-term memory learning capabilities of state-space models (SSMs) from the perspective of parameterization. We prove that state-space models without any reparameterization exhibit a memory limitation similar to that of traditional RNNs: the target relationships that can be stably approximated by state-space models must have an exponentially decaying memory. Our analysis identifies this "curse of memory" as a result of the recurrent weights converging to a stability boundary, suggesting that a reparameterization technique can be effective. To this end, we introduce a class of reparameterization techniques for SSMs that effectively lift their memory limitations. Besides improving approximation capabilities, we further illustrate that a principled choice of reparameterization scheme can also enhance optimization stability. We validate our findings using synthetic datasets, language models and image classification.

Paper Structure (40 sections, 11 theorems, 82 equations, 7 figures, 7 tables)

Key Result

Theorem 3.3

Assume $\mathbf{H}$ is a sequence of bounded, causal, continuous, regular and time-homogeneous functionals on $\mathcal{X}$ with decaying memory. Suppose there exists a sequence of state-space models $\{\widehat{\mathbf{H}}(\cdot, \theta_m)\}_{m=1}^{\infty}$ $\beta_0$-stably approximating $\mathbf{H}$ [...] Here $d$ is the dimension of input sequences. When generalized to multi-layer cases, the memory fun...
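
A minimal sketch of the intuition behind this curse of memory, not the paper's proof. It assumes a hypothetical diagonal linear SSM $h_{k+1} = \Lambda h_k + x_k$, $\hat{y}_k = c^\top h_k$, whose impulse response is $\sum_i c_i \lambda_i^t$. If every eigenvalue stays away from the stability boundary, this response is dominated by an exponential envelope and eventually falls below any polynomially decaying target. The function names, eigenvalues, coefficients and the target exponent are illustrative assumptions.

```python
def impulse_response(lams, coeffs, t):
    """Impulse response sum_i coeffs[i] * lams[i]**t of a diagonal linear SSM."""
    return sum(c * lam ** t for lam, c in zip(lams, coeffs))


def exponential_envelope(lams, coeffs, t):
    """Upper bound ||coeffs||_1 * (max_i |lam_i|)**t on the impulse response."""
    lam_max = max(abs(lam) for lam in lams)
    return sum(abs(c) for c in coeffs) * lam_max ** t


def polynomial_target(t, power=1.5):
    """Hypothetical target memory with polynomial decay 1 / (t + 1)**power."""
    return 1.0 / (t + 1) ** power


# Illustrative values: all eigenvalues satisfy |lam| <= 0.9 < 1.
lams = [0.5, 0.8, 0.9]
coeffs = [1.0, -0.5, 0.3]

for t in (0, 10, 100, 1000):
    print(
        t,
        abs(impulse_response(lams, coeffs, t)),
        exponential_envelope(lams, coeffs, t),
        polynomial_target(t),
    )
```

At long horizons the envelope is many orders of magnitude below the polynomial target, so matching such a target over longer and longer horizons forces some $\lambda_i$ toward the boundary $|\lambda| = 1$.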

Figures (7)

  • Figure 1: State-space models without stable reparameterization cannot approximate targets with polynomially decaying memory. In (a), the intersections of the lines shift to the left as the hidden dimension $m$ increases. In (b), SSMs using softplus reparameterization achieve a stable approximation. In (c), S4 approximates the target with better stability.
  • Figure 2: Scaling of the layer output bound $|\hat{y}| \leq \frac{c}{1-\lambda}$ and of the gradient bound $|\frac{d \hat{y}}{d\lambda}| \leq \frac{c}{(1-\lambda)^2}$. The stability boundary is $\lambda=\pm 1$. When the model adapts to learn long-term memory (as $\lambda$ approaches 1), the gradient bound grows faster than the output bound. Techniques like layer normalization are insufficient to address this exploding-gradient issue effectively.
  • Figure 3: Gradient-over-weight scale range during the training of state-space models. Panel (a) shows the learning of linear functionals with polynomially decaying memory. The "best" discrete parameterization $f(w) = 1 - \frac{1}{w^2 + 0.5}$ maintains the smallest $\max(\frac{|\textrm{grad}|}{|\textrm{weight}|})$, which is crucial for training stability and desirable when a large learning rate is used. Similar results are observed in the language modelling task in panel (b).
  • Figure 4: Language models on WikiText-103. Panel (a), left, shows the gradient-over-weight ratio ranges for different parameterizations of recurrent weights in state-space models. The eigenvalues $\lambda$ are initialized to be the same; the only difference is the reparameterization function $f$. Panel (b), right, shows that the "Best" parameterization is more stable than the ReLU and exponential reparameterizations. Additional experiments for different learning rates are provided in a separate figure (label fig:1LR2LR5LR).
  • Figure 5: An MLP can be realized by a two-layer state-space model. Superscripts indicate layers and subscripts indicate time indices. An MLP is equivalent to an SSM with zero recurrent weights $W_1=W_2=0$.
  • ...and 2 more figures
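
The bounds in the Figure 2 caption and the "best" parameterization in the Figure 3 caption can be related with a small numerical sketch. It assumes a scalar recurrent weight $\lambda = f(w)$ with $0 \leq \lambda < 1$ and uses the gradient bound $\frac{c}{(1-\lambda)^2}$ as a proxy for the actual gradient, so by the chain rule the gradient with respect to $w$ scales like $\frac{c\,|f'(w)|}{(1-f(w))^2}$. The helper names and the choice $c = 1$ are illustrative assumptions, not the authors' code.

```python
import math


def grad_bound(lam, c=1.0):
    """Gradient bound c / (1 - lam)**2 from the Figure 2 caption."""
    return c / (1.0 - lam) ** 2


def f_direct(w):
    """No reparameterization: the trainable weight is the eigenvalue itself."""
    return w


def df_direct(w):
    return 1.0


def f_best(w):
    """'Best' discrete parameterization f(w) = 1 - 1 / (w**2 + 0.5)."""
    return 1.0 - 1.0 / (w * w + 0.5)


def df_best(w):
    """Derivative of f_best with respect to w."""
    return 2.0 * w / (w * w + 0.5) ** 2


def w_best_from_lambda(lam):
    """Positive weight w such that f_best(w) == lam, valid for -1 < lam < 1."""
    return math.sqrt(1.0 / (1.0 - lam) - 0.5)


def grad_over_weight(f, df, w, c=1.0):
    """Proxy for |d y / d w| / |w| using the chain rule and grad_bound."""
    return grad_bound(f(w), c) * abs(df(w)) / abs(w)


for lam in (0.9, 0.99, 0.999):
    direct = grad_over_weight(f_direct, df_direct, lam)
    best = grad_over_weight(f_best, df_best, w_best_from_lambda(lam))
    print(f"lambda={lam}: direct={direct:.1f}, best={best:.3f}")
```

Under these assumptions the proxy ratio for $f(w) = 1 - \frac{1}{w^2 + 0.5}$ stays at $2c$ for every $w$, because $\frac{f'(w)}{(1-f(w))^2} = 2w$, while the ratio for the direct parameterization grows without bound as $\lambda$ approaches 1.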

Theorems & Definitions (36)

  • Definition 2.1: Memory function
  • Definition 2.2: Decaying memory
  • Definition 2.3: Functional sequence approximation in Sobolev-type norm
  • Definition 2.4: Perturbation error
  • Definition 2.5: Stable approximation
  • Theorem 3.3: Curse of memory in SSMs
  • Definition 3.4: Stable reparameterization
  • Theorem 3.5: Existence of stable approximation by stable reparameterization
  • Theorem 3.6: Parameterizations influence the gradient norm scale
  • Remark 3.7: Generalization to multi-layer models
  • ...and 26 more