Towards Understanding Transformers in Learning Random Walks
Wei Shi, Yuan Cao
TL;DR
This work theoretically and empirically analyzes a one-layer transformer trained by gradient descent to predict the next state of a random walk on a circle. It shows that for transition probabilities 0<p<1, the trained model attains optimal prediction accuracy and, importantly, is interpretable: the softmax attention almost always selects the direct parent token, and the value matrix implements the one-step transition. The paper also proves that deterministic walks (p=0 or p=1) are failure cases under zero initialization, and it supports these findings with comprehensive experiments, including extensions to simple question-answering (QA) tasks. Together, these results offer sharp insight into the learnability and interpretability of transformers on classical stochastic processes, and they show how initialization can affect optimization even in simple tasks.
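To make the prediction task concrete, here is a minimal Python sketch. It rests on assumptions this summary does not state: the circle has K >= 3 states, and the walk steps clockwise ((i + 1) mod K) with probability p and counterclockwise ((i - 1) mod K) with probability 1 - p. Because the walk is a Markov chain, no predictor can beat guessing the most likely successor of the last observed state (the direct parent), so the optimal accuracy is max(p, 1 - p). All function names are hypothetical.

```python
import random

def transition_matrix(K, p):
    """One-step transition matrix P[i][j] = Pr(next = j | current = i).

    Assumption (illustrative): from state i the walk moves to (i + 1) % K with
    probability p and to (i - 1) % K with probability 1 - p.
    """
    if K < 3:
        raise ValueError("need K >= 3 so the two neighbours are distinct")
    P = [[0.0] * K for _ in range(K)]
    for i in range(K):
        P[i][(i + 1) % K] += p
        P[i][(i - 1) % K] += 1 - p
    return P

def sample_walk(K, p, length, rng):
    """Sample x_1, ..., x_length, starting from a uniformly random state."""
    walk = [rng.randrange(K)]
    for _ in range(length - 1):
        step = 1 if rng.random() < p else -1
        walk.append((walk[-1] + step) % K)
    return walk

def optimal_predict(P, parent):
    """Bayes-optimal guess: the most likely successor of the parent state."""
    row = P[parent]
    return max(range(len(row)), key=row.__getitem__)

def empirical_accuracy(K, p, length, trials, seed=0):
    """Fraction of sampled walks whose last state is predicted correctly."""
    rng = random.Random(seed)
    P = transition_matrix(K, p)
    hits = 0
    for _ in range(trials):
        walk = sample_walk(K, p, length, rng)
        hits += optimal_predict(P, walk[-2]) == walk[-1]
    return hits / trials

print(empirical_accuracy(K=8, p=0.7, length=10, trials=20000))
```

With p = 0.7 the printed estimate should be close to max(p, 1 - p) = 0.7. Conditioning on the whole history instead of only the parent state would not raise this value, which is why a model that attends to the parent token can be optimal.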
Abstract
Transformers have proven highly effective across various applications, especially in handling sequential data such as natural language and time series. However, transformer models often lack clear interpretability, and their success is not yet well understood in theory. In this paper, we study the capability and interpretability of transformers in learning a family of classic statistical models, namely random walks on circles. We theoretically demonstrate that, after training with gradient descent, a one-layer transformer model can achieve optimal accuracy in predicting random walks. Importantly, our analysis reveals that the trained model is interpretable: the trained softmax attention serves as a token selector, focusing on the direct parent state; the value matrix then executes a one-step probability transition to predict the location of the next state from this parent state. We also show that certain edge cases not covered by our theory are indeed failure cases, demonstrating that our theoretical conditions are tight. These success and failure cases further reveal that gradient descent with small initialization may fail, or struggle to converge to a good solution, in certain simple tasks even beyond random walks. Experiments support our theoretical findings.
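The interpretable solution described above (attention as a parent-token selector, followed by a one-step transition in the value matrix) can be illustrated with a hand-set, idealized layer. This is a sketch under stated assumptions, not the paper's parameterization or its trained weights: it assumes one-hot token embeddings, a walk on K states that steps +1 (mod K) with probability p and -1 (mod K) with probability 1 - p, attention scores equal to a hypothetical `gap` on the parent token and 0 elsewhere, and a value matrix equal to the transposed transition matrix. All names are hypothetical.

```python
import math

def transition_matrix(K, p):
    # Assumed walk: +1 (mod K) with probability p, -1 (mod K) with probability 1 - p.
    P = [[0.0] * K for _ in range(K)]
    for i in range(K):
        P[i][(i + 1) % K] += p
        P[i][(i - 1) % K] += 1 - p
    return P

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def predict_next(walk, K, p, gap):
    """Hand-set (hypothetical, not trained) attention layer.

    Scores are `gap` on the parent token and 0 elsewhere; returns a
    probability vector over the K possible next states.
    """
    scores = [0.0] * len(walk)
    scores[-1] = gap  # the parent of the next state is the last observed token
    weights = softmax(scores)
    # Attention output over one-hot embeddings: attended[k] = total weight on state k.
    attended = [0.0] * K
    for w, x in zip(weights, walk):
        attended[x] += w
    # Value step with V = P^T: out[j] = sum_i P[i][j] * attended[i].
    P = transition_matrix(K, p)
    return [sum(P[i][j] * attended[i] for i in range(K)) for j in range(K)]

walk = [0, 1, 2, 3]
sharp = predict_next(walk, K=8, p=0.7, gap=20.0)    # attention concentrated on the parent (state 3)
uniform = predict_next(walk, K=8, p=0.7, gap=0.0)   # attention spread evenly over all tokens
print(sharp[4], uniform[4])
```

With a large score gap the output is essentially row 3 of the transition matrix: about 0.7 of the mass on state 4 and 0.3 on state 2. With uniform attention, only 0.175 of the mass lands on the true successor state 4, because the output mixes the successors of every token in the sequence. This contrast shows why a token-selecting attention pattern is what makes the one-step transition in the value matrix yield optimal predictions.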
