State-space models can learn in-context by gradient descent

Neeraj Mohan Sushma, Yudou Tian, Harshvardhan Mestha, Nicolo Colombo, David Kappel, Anand Subramoney

TL;DR

The paper presents an explicit construction showing that a one-layer gated state-space model (SSM) can perform minibatch gradient descent on an implicit linear regression loss, enabling in-context learning (ICL) that parallels transformer-based ICL. It extends the construction to multi-step gradient descent and, via an MLP, to non-linear regression, and shows empirically that randomly initialized models of this type (GD-SSMs) trained on linear and non-linear regression tasks converge to parameters matching the theoretical construction. The work identifies input and output gating, together with a sliding-window input scheme, as the inductive biases that enable expressive ICL in SSMs, and connects the approach to linear self-attention. Overall, it provides a mechanistic explanation of ICL in SSMs and offers a parsimonious, scalable alternative to transformer-based ICL with potential practical impact on sequence modeling.
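
The "Gradient Descent (GD)" baseline that the GD-SSM is compared against can be sketched as plain full-batch GD on the in-context examples, started from W = 0. This is a minimal NumPy sketch under our own assumptions: the unnormalized loss 0.5 * sum_j ||W x_j - y_j||^2, noise-free linear targets, and a step size below 1/L. The function names `gd_weights` and `context_loss` are hypothetical. The paper's minibatch form and normalization may differ. With `steps=1` the sketch gives the single step the one-layer construction reproduces; larger `steps` gives the multi-step target.

```python
import numpy as np

def gd_weights(X_ctx, Y_ctx, eta, steps):
    """Full-batch GD from W = 0 on the assumed loss 0.5 * sum_j ||W x_j - y_j||^2."""
    W = np.zeros((Y_ctx.shape[1], X_ctx.shape[1]))
    for _ in range(steps):
        # Gradient sum_j (W x_j - y_j) x_j^T, shape (d_y, d_x).
        grad = (X_ctx @ W.T - Y_ctx).T @ X_ctx
        W = W - eta * grad
    return W

def context_loss(W, X_ctx, Y_ctx):
    """Value of the assumed in-context least-squares loss at W."""
    return 0.5 * float(np.sum((X_ctx @ W.T - Y_ctx) ** 2))

rng = np.random.default_rng(0)
d_x, d_y, N = 3, 1, 20
W_true = rng.normal(size=(d_y, d_x))
X_ctx = rng.normal(size=(N, d_x))
Y_ctx = X_ctx @ W_true.T                      # noise-free linear targets (assumption)
x_query = rng.normal(size=d_x)
eta = 0.5 / np.linalg.norm(X_ctx, 2) ** 2     # step size below 1/L, so the loss decreases

W1 = gd_weights(X_ctx, Y_ctx, eta, steps=1)   # one step: the one-layer target
W5 = gd_weights(X_ctx, Y_ctx, eta, steps=5)   # several steps: a multi-step target
print(W1 @ x_query, W5 @ x_query, W_true @ x_query)
```

Starting from W = 0 makes one step collapse to `eta * Y_ctx.T @ X_ctx`, a sum of outer products y_j x_j^T. A recurrent state can accumulate exactly this kind of quantity.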

Abstract

Deep state-space models (Deep SSMs) are becoming popular as effective approaches to model sequence data. They have also been shown to be capable of in-context learning, much like transformers. However, a complete picture of how SSMs might be able to do in-context learning has been missing. In this study, we provide a direct and explicit construction to show that state-space models can perform gradient-based learning and use it for in-context learning in much the same way as transformers. Specifically, we prove that a single structured state-space model layer, augmented with multiplicative input and output gating, can reproduce the outputs of an implicit linear model with least squares loss after one step of gradient descent. We then show a straightforward extension to multi-step linear and non-linear regression tasks. We validate our construction by training randomly initialized augmented SSMs on linear and non-linear regression tasks. The parameters obtained empirically through optimization match those predicted analytically by the theoretical construction. Overall, we elucidate the role of input- and output-gating in recurrent architectures as the key inductive biases for enabling the expressive power typical of foundation models. We also provide novel insights into the relationship between state-space models and linear self-attention, and their ability to learn in-context.
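
The core identity behind the one-step claim can be written out in a few lines. The derivation below assumes the unnormalized least-squares loss and $\bm{W}_0 = \bm{0}$; the paper's own normalization and minibatch form may differ. The state symbol $\bm{H}_j$ is our hypothetical notation. One GD step turns the prediction into a running sum of outer products applied to the query, which a recurrent layer can accumulate. The same expression also has the form of unnormalized linear self-attention with keys $\bm{x}_j$, values $y_j$, and query $\bm{x}_{N+1}$.

```latex
\begin{aligned}
\mathcal{L}(\bm{W}) &= \tfrac{1}{2}\sum_{j=1}^{N} \lVert \bm{W}\bm{x}_j - y_j \rVert^2,
\qquad
\nabla_{\bm{W}}\mathcal{L}(\bm{W}) = \sum_{j=1}^{N} (\bm{W}\bm{x}_j - y_j)\,\bm{x}_j^\top \\
\bm{W}_1 &= \bm{W}_0 - \eta\,\nabla_{\bm{W}}\mathcal{L}(\bm{W}_0)
          = \eta \sum_{j=1}^{N} y_j\,\bm{x}_j^\top \qquad (\bm{W}_0 = \bm{0}) \\
\hat{y}_{N+1} &= \bm{W}_1 \bm{x}_{N+1}
  = \eta \sum_{j=1}^{N} y_j\,(\bm{x}_j^\top \bm{x}_{N+1})
  = \eta\,\bm{H}_N \bm{x}_{N+1},
  \qquad \bm{H}_j = \bm{H}_{j-1} + y_j\,\bm{x}_j^\top,\; \bm{H}_0 = \bm{0}
\end{aligned}
```

The last line shows where the two gates enter. Input gating must form the product $y_j\,\bm{x}_j^\top$ before it is added to the state. Output gating must multiply the state by the query $\bm{x}_{N+1}$.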

Paper Structure

This paper contains 33 sections, 4 theorems, 49 equations, 6 figures, 4 tables.

Key Result

Proposition 4.1

Given a diagonal linear recurrent layer and tokens ${\bm{s}}_j = {\bm{c}}_j = [{\bm{x}}_j, y_j, {\bm{x}}_{j+1}]$ for $j = 1, \ldots, N$, where $[\ldots]$ denotes concatenation and ${\bm{x}}_j, y_j$ are drawn from a linear model, one can construct a recurrent matrix ${\bm{A}}_j({\bm{S}}')$, input ${\bm{B}}_j(\ldots)$, …
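
The mechanism the proposition describes can be illustrated with a minimal NumPy sketch that checks a gated diagonal recurrence against one GD step at every position of the sequence. Everything specific here is our own assumption, not the paper's parametrization of ${\bm{A}}_j$ and ${\bm{B}}_j$:

- an all-ones (non-forgetting) diagonal recurrence;
- input gating as the elementwise product of two selections of the token, which yields vec($y_j x_j^\top$);
- output gating that multiplies the state by the query ${\bm{x}}_{j+1}$;
- the unnormalized loss 0.5 * sum_i ||W x_i - y_i||^2.

The names `make_tokens`, `gated_diag_recurrence`, and `one_step_gd_predictions` are hypothetical.

```python
import numpy as np

def make_tokens(X, Y):
    """Sliding-window tokens s_j = [x_j, y_j, x_{j+1}], j = 1..N (X has N+1 rows)."""
    N = Y.shape[0]
    return np.concatenate([X[:N], Y, X[1:N + 1]], axis=1)

def gated_diag_recurrence(S, d_x, d_y, eta):
    """Hypothetical gated diagonal linear recurrence; NOT the paper's exact A_j, B_j."""
    d_s, n = S.shape[1], d_x * d_y          # state = vec of a (d_y, d_x) matrix
    P_y, P_x, P_q = (np.zeros((n, d_s)) for _ in range(3))
    for a in range(d_y):
        for b in range(d_x):
            k = a * d_x + b                  # row-major index of entry (a, b)
            P_y[k, d_x + a] = 1.0            # selects y_j[a]
            P_x[k, b] = 1.0                  # selects x_j[b]
            P_q[k, d_x + d_y + b] = 1.0      # selects x_{j+1}[b]
    R = np.kron(np.eye(d_y), np.ones((1, d_x)))  # sums over b for each output index a
    lam = np.ones(n)                         # diagonal recurrence; all ones = no forgetting (assumption)
    h, outs = np.zeros(n), []
    for s in S:
        h = lam * h + (P_y @ s) * (P_x @ s)      # input gating adds vec(y_j x_j^T)
        outs.append(eta * R @ ((P_q @ s) * h))   # output gating with query x_{j+1}
    return np.array(outs)

def one_step_gd_predictions(X, Y, eta):
    """For each j: one GD step from W = 0 on 0.5 * sum_{i<=j} ||W x_i - y_i||^2, then predict at x_{j+1}."""
    preds = []
    for j in range(Y.shape[0]):
        Xc, Yc = X[:j + 1], Y[:j + 1]
        W = np.zeros((Y.shape[1], X.shape[1]))
        W = W - eta * ((Xc @ W.T - Yc).T @ Xc)
        preds.append(W @ X[j + 1])
    return np.array(preds)

rng = np.random.default_rng(0)
d_x, d_y, N, eta = 3, 2, 10, 0.1
W_true = rng.normal(size=(d_y, d_x))
X = rng.normal(size=(N + 1, d_x))
Y = X[:N] @ W_true.T
ssm = gated_diag_recurrence(make_tokens(X, Y), d_x, d_y, eta)
gd = one_step_gd_predictions(X, Y, eta)
print(np.allclose(ssm, gd))  # prints True
```

The sliding window is what lets a single token carry both the current example $({\bm{x}}_j, y_j)$ for the state update and the next input ${\bm{x}}_{j+1}$ for the readout.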

Figures (6)

  • Figure 1: Comparing one step of GD with a trained single-layer GD-SSM for 1-dimensional regression. A: The loss of the trained single-layer GD-SSM and that of a GD-SSM with the parameters from our construction are identical. B: Cosine similarity and L2 distance between the models and between their predictions. C: Comparison of loss between gradient descent (GD) and the GD-SSM layer for different input sizes $f$. D: The trained single-layer 1-D GD-SSM and gradient descent show identical loss (log scale) when given input data that differ from those seen during training, i.e. with a scale of 1. We show the mean/std. or the individual runs over 5 seeds.
  • Figure 2: Comparing one step of GD with a trained single-layer GD-SSM for N-dimensional regression. A: The loss of the trained single-layer GD-SSM and that of a GD-SSM with the parameters from our construction are identical. B: Cosine similarity and L2 distance between the models and between their predictions. C: Comparison of loss between gradient descent (GD) and the GD-SSM layer for different input sizes $f$. D: The trained GD-SSM layer and gradient descent show identical loss (log scale) when given input data that differ from those seen during training, i.e. with a scale of 1. We show the mean/std. or the individual runs over 5 seeds.
  • Figure 3: Comparison of learned weights between a trained model and our construction. A: Input-gating weights of the trained GD-SSM compared with the GD-SSM parameters from our construction. B: Recurrent parameters of the trained GD-SSM compared with the GD-SSM parameters from our construction. Since the recurrence parameters are tensors, for ease of visualization each diagonal entry shows the mean of the corresponding diagonal recurrence matrix. C: Skip-connection weights of the trained GD-SSM compared with the GD-SSM parameters from our construction.
  • Figure 4: A: Comparison of performance on general regression tasks. B: Comparison of GD-SSM performance with and without MLP layers in single-layer and multi-layer setups on a non-linear regression task. C: Comparison with other models on 1-D linear regression. The GD-SSM model was evaluated in a 1-layer configuration; S5, Mamba (+LP), and Griffin (+LP), where LP denotes a linear projection, were included for comparison. Linear SA denotes linear self-attention models, tested in both 1-layer and 2-layer variants. A similar comparison for N-D linear regression is presented in an appendix figure.
  • Figure 5: Visualisation of the trained parameters of a two-layer GD-SSM on the linear regression task.
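
Panel B of Figures 1 and 2 reports cosine similarity and L2 distance between models and between their predictions. The sketch below assumes the standard definitions applied to flattened arrays; how the paper aggregates over sequences, seeds, or parameter groups is not stated here. To compare whole models, one could flatten and concatenate all parameter arrays (e.g. with `np.concatenate([p.ravel() for p in params])`) before calling the function. The function name and the example arrays are hypothetical.

```python
import numpy as np

def cosine_and_l2(a, b):
    """Cosine similarity and Euclidean (L2) distance between two arrays, flattened."""
    a = np.ravel(np.asarray(a, dtype=float))
    b = np.ravel(np.asarray(b, dtype=float))
    cos = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    l2 = float(np.linalg.norm(a - b))
    return cos, l2

# Hypothetical example: predictions of a trained model vs. the analytical construction.
pred_trained = np.array([0.98, -1.51, 0.49])
pred_construction = np.array([1.00, -1.50, 0.50])
print(cosine_and_l2(pred_trained, pred_construction))
```

Reporting both numbers is informative because cosine similarity ignores overall scale, while the L2 distance does not.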

Theorems & Definitions (4)

  • Proposition 4.1
  • Proposition 4.2
  • Proposition A.1
  • Proposition A.2