
Towards a theory of learning dynamics in deep state space models

Jakub Smékal, Jimmy T. H. Smith, Michael Kleinman, Dan Biderman, Scott W. Linderman

TL;DR

The paper addresses a gap in the theory of learning dynamics for state-space models (SSMs). It analyzes gradient-descent training of linear SSMs on a squared loss, focusing on how data covariance and latent state size shape the evolution of the parameters. By transforming the SSM into the Fourier domain, it derives analytic continuous-time dynamics for the one-layer case and connects them to the learning behavior of deep linear feed-forward networks, with the data covariances summarized by the sufficient statistics $\sigma$ and $\eta$. It shows that over-parameterization (a larger latent size $N$) can accelerate convergence, with time constants scaling as $O\left(\frac{\tau}{N\sigma}\right)$ in balanced setups and $O\left(\frac{\tau}{N^2\eta}\right)$ in other regimes, which links SSM learning dynamics to established deep-network theory. These results are a principled step toward a comprehensive theory of learning dynamics in deep state-space models and motivate extensions to multi-layer and nonlinear SSMs.

Abstract

State space models (SSMs) have shown remarkable empirical performance on many long sequence modeling tasks, but a theoretical understanding of these models is still lacking. In this work, we study the learning dynamics of linear SSMs to understand how covariance structure in data, latent state size, and initialization affect the evolution of parameters throughout learning with gradient descent. We show that focusing on the learning dynamics in the frequency domain affords analytical solutions under mild assumptions, and we establish a link between one-dimensional SSMs and the dynamics of deep linear feed-forward networks. Finally, we analyze how latent state over-parameterization affects convergence time and describe future work in extending our results to the study of deep SSMs with nonlinear connections. This work is a step toward a theory of learning dynamics in deep state space models.
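The link to deep linear networks and the effect of latent size can be illustrated with a deliberately minimal toy. This is an assumption of the sketch, not the paper's derivation: every $a_i$ is set to $0$ and the feedthrough term is dropped, so the frequency response reduces to $H = \sum_{i=1}^{N} c_i b_i$ at every frequency and the squared loss becomes $\tfrac{1}{2}\sigma (H - H^*)^2$, with $\sigma$ the input power and $H^*$ a target gain. Each term $c_i b_i$ is then a two-layer scalar linear network. With a balanced initialization $c_i = b_i = w_0$, gradient descent produces the sigmoidal learning curves known from deep linear network theory, and a larger $N$ starts from a larger gain $H(0) = N w_0^2$ and so needs fewer steps. The function `train_toy` and all of its defaults are hypothetical.

```python
import numpy as np


def train_toy(N, sigma=1.0, h_target=1.0, w0=0.05, lr=0.01, max_steps=5000, tol=1e-3):
    """Gradient descent on L = 0.5 * sigma * (sum_i c_i b_i - h_target)^2.

    Hypothetical toy: a_i = 0 and no feedthrough, so H = c . b at every frequency.
    Returns (steps until |H - h_target| < tol * |h_target|, final H).
    """
    c = np.full(N, w0)  # balanced initialization: c_i = b_i = w0
    b = np.full(N, w0)
    for step in range(max_steps):
        H = c @ b
        err = H - h_target
        if abs(err) < tol * abs(h_target):
            return step, H
        grad_c = sigma * err * b  # dL/dc_i
        grad_b = sigma * err * c  # dL/db_i
        # simultaneous update, both gradients use the pre-step parameters
        c, b = c - lr * grad_c, b - lr * grad_b
    return max_steps, c @ b


for N in (1, 4, 16):
    steps, H = train_toy(N)
    print(f"N={N:2d}: converged in {steps} steps, H={H:.4f}")
```

In this toy, a larger $\sigma$ also shortens training, because the logistic rate is proportional to $\sigma H^*$, while the benefit of a larger $N$ enters only through the larger initial gain, so the step count falls roughly logarithmically in $N$. The $O\left(\frac{\tau}{N\sigma}\right)$ and $O\left(\frac{\tau}{N^2\eta}\right)$ scalings stated in the TL;DR come from the paper's own analysis and are not reproduced by this sketch.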

Paper Structure

This paper contains 8 sections, 3 theorems, 28 equations, and 1 figure.

Key Result

Proposition 1

Let $U_k \in \mathbb{C}$ and $Y_k \in \mathbb{C}$ for $k=1,\ldots,L$ denote the discrete Fourier transform (DFT) of the inputs $u_{1:T}$ and outputs $y_{1:T}$, respectively. For diagonal dynamics matrices $A=\mathrm{diag}(a_1,\ldots, a_N)$ with $|a_i|<1$ for all $i=1,\ldots,N$ to ensure stability, t
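The frequency-domain view behind Proposition 1 can be checked numerically. The sketch below assumes real diagonal dynamics $a_i \in (-1, 1)$, input and output vectors $b$ and $c$, a zero initial state, and no feedthrough term. It runs the recurrence $x_t = \mathrm{diag}(a)\,x_{t-1} + b\,u_t$, $y_t = c^\top x_t$ directly and compares the result with pointwise multiplication of DFTs, $Y_k = H_k U_k$. The function names `ssm_scan`, `impulse_response`, `ssm_fft`, and `frequency_response` are hypothetical, and zero padding to length $2L$ (so that the circular convolution implied by the DFT equals the linear one) is a choice of this sketch, not necessarily the convention used in the paper.

```python
import numpy as np


def ssm_scan(a, b, c, u):
    """Time-domain recurrence x_t = diag(a) x_{t-1} + b u_t, y_t = c . x_t, with x_{-1} = 0."""
    x = np.zeros(len(a))
    y = np.empty(len(u))
    for t, u_t in enumerate(u):
        x = a * x + b * u_t  # a diagonal A acts elementwise
        y[t] = c @ x
    return y


def impulse_response(a, b, c, L):
    """Convolution kernel K_t = sum_i c_i b_i a_i^t for t = 0, ..., L-1."""
    t = np.arange(L)
    return (c * b) @ (a[:, None] ** t[None, :])


def ssm_fft(a, b, c, u):
    """Same output via the DFT: convolution in time becomes multiplication in frequency."""
    L = len(u)
    n = 2 * L  # zero-pad so the DFT's circular convolution matches the linear one
    K = impulse_response(a, b, c, L)
    Y = np.fft.rfft(K, n) * np.fft.rfft(u, n)  # Y_k = H_k U_k
    return np.fft.irfft(Y, n)[:L]


def frequency_response(a, b, c, n):
    """Closed form H_k = sum_i c_i b_i / (1 - a_i e^{-2 pi i k / n}) on the rfft bins."""
    k = np.arange(n // 2 + 1)
    z = np.exp(-2j * np.pi * k / n)
    return (c * b) @ (1.0 / (1.0 - a[:, None] * z[None, :]))


rng = np.random.default_rng(0)
N, L = 4, 256
a = rng.uniform(-0.9, 0.9, size=N)  # |a_i| < 1 for stability
b, c = rng.normal(size=N), rng.normal(size=N)
u = rng.normal(size=L)

print(np.allclose(ssm_scan(a, b, c, u), ssm_fft(a, b, c, u)))  # True
print(np.allclose(np.fft.rfft(impulse_response(a, b, c, L), 2 * L),
                  frequency_response(a, b, c, 2 * L)))  # True
```

The second check agrees only up to terms of order $a_i^L$, which are negligible here because $|a_i| < 0.9$ and $L = 256$; for slowly decaying modes the truncated kernel and the closed-form $H_k$ would differ visibly.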

Figures (1)

  • Figure 1: Learning dynamics of SSMs in the frequency domain. A. A linear SSM, as defined in the main text, unrolled for a length-$L$ sequence. B. After applying the discrete Fourier transform, the SSM is fully described by its frequency response $H_k$, which turns a recurrence in the time domain into modulated scalar multiplication in the frequency domain. C. An example input signal in the time domain. D. The discrete Fourier transform of the input signal. E. Even under strong assumptions, the analytical learning dynamics derived in the paper approximate the empirical evolution of the SSM for simple input-output modes; each subplot shows the evolution of the frequency response for an individual input-output pair. F. Extending the theory to $N$-dimensional one-layer SSMs shows how over-parameterization of the latent state can lead to faster convergence. Solid lines denote trajectories obtained with automatic differentiation; dashed lines are numerical simulations of the analytical solution to the learning dynamics.

Theorems & Definitions (5)

  • Proposition 1
  • Proposition 2
  • Proof of Proposition 1 (DFT)
  • Proof of the gradient-descent proposition
  • Proposition 3