State-space models can learn in-context by gradient descent
Neeraj Mohan Sushma, Yudou Tian, Harshvardhan Mestha, Nicolo Colombo, David Kappel, Anand Subramoney
TL;DR
The paper presents an explicit construction showing that one-layer gated state-space models (SSMs) can perform minibatch gradient descent on an implicit linear regression loss, enabling in-context learning that parallels transformer-based ICL. It extends the construction to multi-step gradient descent and to non-linear regression via an MLP, and demonstrates empirically that randomly initialized GD-SSMs trained on linear and non-linear regression tasks arrive at parameters matching the theoretical design. The work highlights the crucial role of input- and output-gating, along with a sliding-window input scheme, as the inductive biases enabling expressive ICL in SSMs and connects the approach to linear self-attention. Overall, it provides a mechanistic explanation for ICL in SSMs and offers a parsimonious, scalable alternative to transformer-based ICL with potential practical impact for sequence modeling.
Abstract
Deep state-space models (Deep SSMs) are becoming popular as effective approaches to model sequence data. They have also been shown to be capable of in-context learning, much like transformers. However, a complete picture of how SSMs might be able to do in-context learning has been missing. In this study, we provide a direct and explicit construction to show that state-space models can perform gradient-based learning and use it for in-context learning in much the same way as transformers. Specifically, we prove that a single structured state-space model layer, augmented with multiplicative input and output gating, can reproduce the outputs of an implicit linear model with least squares loss after one step of gradient descent. We then show a straightforward extension to multi-step linear and non-linear regression tasks. We validate our construction by training randomly initialized augmented SSMs on linear and non-linear regression tasks. The parameters obtained empirically through optimization match those predicted analytically by the theoretical construction. Overall, we elucidate the role of input- and output-gating in recurrent architectures as the key inductive biases for enabling the expressive power typical of foundation models. We also provide novel insights into the relationship between state-space models and linear self-attention, and their ability to learn in-context.
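The central claim of the abstract can be illustrated with a small numerical check. The sketch below is not the paper's construction. It is a minimal toy version under stated assumptions: scalar targets, an implicit linear model initialized at zero, the unnormalized loss 0.5 * sum_i (w . x_i - y_i)^2, and a recurrence with an identity state transition. The function names `gd_one_step_predict` and `gated_recurrence_predict`, the learning rate `lr`, and the example data are all hypothetical. Under these assumptions, one gradient step gives w_1 = lr * sum_i y_i x_i, which a linear recurrence can accumulate when each input is first multiplied by its target (input gating) and the final state is multiplied by the query (output gating).

```python
def gd_one_step_predict(xs, ys, x_query, lr):
    """Prediction of a linear model after one explicit gradient-descent step.

    Assumes zero initial weights and loss 0.5 * sum_i (w . x_i - y_i)^2.
    """
    d = len(x_query)
    w = [0.0] * d
    grad = [0.0] * d
    for x, y in zip(xs, ys):
        err = sum(wj * xj for wj, xj in zip(w, x)) - y
        for j in range(d):
            grad[j] += err * x[j]
    w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return sum(wj * qj for wj, qj in zip(w, x_query))


def gated_recurrence_predict(xs, ys, x_query, lr):
    """Same prediction computed by a gated linear recurrence (toy sketch)."""
    d = len(x_query)
    state = [0.0] * d
    for x, y in zip(xs, ys):
        # Input gating: multiplicative interaction between target and input.
        u = [y * xj for xj in x]
        # Linear recurrence with identity transition: s_t = s_{t-1} + u_t.
        state = [s + uj for s, uj in zip(state, u)]
    # Output gating: multiply the state elementwise by the query, then sum.
    return lr * sum(s * q for s, q in zip(state, x_query))


# Hypothetical in-context examples generated by y = 2*x[0] - 1*x[1].
xs = [[1.0, 2.0], [0.5, -1.0], [-1.5, 0.25]]
ys = [0.0, 2.0, -3.25]
x_query = [1.0, 1.0]

pred_gd = gd_one_step_predict(xs, ys, x_query, lr=0.1)
pred_ssm = gated_recurrence_predict(xs, ys, x_query, lr=0.1)
print(pred_gd, pred_ssm)
```

The two functions agree up to floating-point error because, starting from zero weights, the gradient depends only on the sum of target-weighted inputs. That sum is exactly what the recurrent state holds after the last in-context example. Nonzero initial weights or multiple gradient steps would need additional terms that this toy version leaves out.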
