Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot
Zixuan Wang, Stanley Wei, Daniel Hsu, Jason D. Lee
TL;DR
The paper studies the learnability of the sparse token selection task $\text{STS}_q$ by analyzing the gradient-descent (GD) dynamics of a one-layer transformer with stochastic positional encoding (SPE). It proves that, with width $m = \tilde{O}(d + q\log T)$, GD converges globally to a solution of $\text{STS}_q$, whereas any FCN requires width $\tilde{\Omega}(Td)$ to achieve comparable performance, establishing an algorithmic separation between transformers and FCNs. It also proves a length-generalization guarantee under SPE and provides empirical evidence of convergence to the ground-truth directions and of better out-of-distribution generalization than fixed positional encodings. Together, these results support the inductive bias of transformers for this class of sparse selection tasks and offer insight into how randomized positional encodings aid robust length generalization.
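A minimal NumPy sketch of the task and of why a single attention head suits it. It assumes the formulation of Sanford et al., in which $\text{STS}_q$ asks for the average of the $q$ tokens indexed by a given subset. The parameters are hand-set rather than trained, and the names `sts_target` and `one_layer_attention`, the `scale` value, and the exactly orthonormal positional encodings are hypothetical illustrative choices, not the paper's trained model or its SPE scheme.

```python
import numpy as np


def sts_target(X, subset):
    """Assumed STS_q target: the mean of the tokens whose indices are in `subset`.

    X has shape (T, d); subset holds q distinct indices in range(T).
    """
    return X[list(subset)].mean(axis=0)


def one_layer_attention(X, subset, P, scale=50.0):
    """Hypothetical hand-constructed softmax-attention head (not trained).

    P has shape (T, k) with orthonormal rows (requires T <= k). The query is the
    scaled sum of the encodings of the selected positions, so a selected position
    scores `scale` and every other position scores 0.
    """
    query = scale * P[list(subset)].sum(axis=0)  # shape (k,)
    scores = P @ query                           # shape (T,)
    weights = np.exp(scores - scores.max())      # numerically stable softmax
    weights /= weights.sum()                     # ~1/q on selected positions
    return weights @ X                           # shape (d,)


rng = np.random.default_rng(0)
d, k, q = 8, 128, 3  # token dim, encoding dim, subset size (assumed values)
errors = {}
for T in (16, 64):
    X = rng.standard_normal((T, d))
    subset = rng.choice(T, size=q, replace=False)
    # Reduced QR gives a (k, T) matrix with orthonormal columns; transpose it
    # so each of the T positions gets an orthonormal row encoding.
    Q, _ = np.linalg.qr(rng.standard_normal((k, T)))
    P = Q.T
    out = one_layer_attention(X, subset, P)
    errors[T] = float(np.abs(out - sts_target(X, subset)).max())
print(errors)
```

Because the query depends only on the encodings of the selected positions, the same construction applies unchanged at both lengths as long as the encodings remain orthonormal. Whether gradient descent actually reaches such a solution, and generalizes to unseen lengths, is the question the paper's analysis addresses.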
Abstract
The transformer architecture has prevailed in various deep learning settings due to its exceptional ability to select and compose structural information. Motivated by this ability, Sanford et al. proposed the sparse token selection task, in which transformers excel while fully-connected networks (FCNs) fail in the worst case. Building on this, we strengthen the FCN lower bound to an average-case setting and establish an algorithmic separation of transformers over FCNs. Specifically, a one-layer transformer trained with gradient descent provably learns the sparse token selection task and, surprisingly, exhibits strong out-of-distribution length generalization. We provide empirical simulations to support our theoretical findings.
