RAM: Replace Attention with MLP for Efficient Multivariate Time Series Forecasting
Suhan Guo, Jiahong Deng, Yi Wei, Hui Dou, Furao Shen, Jian Zhao
TL;DR
The paper addresses the high computational cost of attention in multivariate time series forecasting by proposing RAM, a method that replaces attention with an MLP-based structure. RAM shows that the Q, K, and V projections and the attention mappings can be pruned without substantial loss in accuracy. This cuts FLOPs by 62.579% for spatio-temporal models and by 42.233% for long-term forecasting models, while accuracy remains competitive in both settings. It introduces an abstract AMTSFM framework and shows that the encoder and decoder attention modules are not equally critical, with the feedforward and residual components driving the MLP's performance. The approach has practical implications for deploying efficient forecasting models on resource-constrained devices. It also raises broader questions about whether attention is necessary in time series models across domains.
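To see why pruning the Q, K, V projections and the attention mappings saves compute, here is a back-of-the-envelope FLOP count for a single Transformer layer. It uses the common convention that a multiply-add costs 2 FLOPs and ignores softmax, layer norm, and biases. The helper names (`attention_flops`, `ffn_flops`, `attention_share`) and the example sizes (96 or 720 tokens, width 512, feedforward width 2048) are hypothetical and are not taken from the paper. The reported 62.579% and 42.233% reductions come from full models whose embeddings, decoders, and other non-attention parts this toy count leaves out.

```python
def attention_flops(n, d):
    """Approximate FLOPs of one self-attention sublayer (n tokens, width d).

    Counts a multiply-add as 2 FLOPs; softmax, scaling, and biases are ignored.
    """
    qkv = 3 * 2 * n * d * d   # Q, K, V projections
    scores = 2 * n * n * d    # Q @ K^T
    mix = 2 * n * n * d       # attention scores @ V
    out = 2 * n * d * d       # final output projection
    return qkv + scores + mix + out


def ffn_flops(n, d, d_ff):
    """Approximate FLOPs of the position-wise feedforward sublayer d -> d_ff -> d."""
    return 2 * n * d * d_ff + 2 * n * d_ff * d


def attention_share(n, d, d_ff):
    """Fraction of one layer's (attention + FFN) FLOPs spent in attention."""
    attn = attention_flops(n, d)
    return attn / (attn + ffn_flops(n, d, d_ff))


for n in (96, 720):
    print(n, f"{attention_share(n, d=512, d_ff=2048):.3f}")
# 96 0.354
# 720 0.460
```

In this toy count, attention's share of the layer rises with the number of tokens n because the score and mixing terms scale with n², while the projections and the feedforward network scale linearly in n. Splitting attention into several heads that partition d leaves these totals unchanged.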
Abstract
Attention-based architectures have become ubiquitous in time series forecasting, including spatio-temporal forecasting (STF) and long-term time series forecasting (LTSF). Yet our understanding of why they are effective remains limited. In this work, we propose a novel pruning strategy, Replace Attention with MLP (RAM), that approximates the attention mechanism using only feedforward layers, residual connections, and layer normalization for temporal and/or spatial modeling in multivariate time series forecasting. Specifically, the following components can be removed from attention-based networks without significantly degrading performance: the Q, K, and V projections, the attention score calculation, the product of the attention scores with V, and the final output projection. The pruned network still ranks among the top tier of state-of-the-art (SOTA) methods. RAM achieves a 62.579% reduction in FLOPs for spatio-temporal models with less than a 2.5% performance drop, and a 42.233% FLOPs reduction for LTSF models with less than a 2% performance drop.
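The abstract lists which attention components RAM removes and which remain: feedforward layers, residual connections, and layer normalization. The NumPy sketch that follows contrasts a standard post-norm Transformer encoder layer with one possible reading of the pruned layer. In this reading, the whole attention sublayer is deleted, and only the feedforward network, its residual connection, and layer normalization are kept. This is an assumption for illustration, not the authors' implementation. The function names (`full_block`, `pruned_block`, `init_params`, and so on), the single attention head, the post-norm ordering, and the weight initialization are all hypothetical choices. The 12 tokens can stand for time steps (temporal modeling) or for variables or nodes (spatial modeling).

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize each token's feature vector (learned scale/shift omitted for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def ffn(x, w1, b1, w2, b2):
    # Position-wise feedforward network: Linear -> ReLU -> Linear.
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2


def self_attention(x, wq, wk, wv, wo):
    # Single-head scaled dot-product attention: the components RAM prunes.
    q, k, v = x @ wq, x @ wk, x @ wv                      # Q, K, V projections
    scores = q @ k.T / np.sqrt(q.shape[-1])               # attention scores
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)              # row-wise softmax
    return (attn @ v) @ wo                                # scores @ V, output projection


def full_block(x, p):
    # Post-norm Transformer encoder layer: attention sublayer, then FFN sublayer.
    h = layer_norm(x + self_attention(x, p["wq"], p["wk"], p["wv"], p["wo"]))
    return layer_norm(h + ffn(h, p["w1"], p["b1"], p["w2"], p["b2"]))


def pruned_block(x, p):
    # Hypothetical RAM-style layer: the attention sublayer is deleted, keeping
    # only the FFN, its residual connection, and layer normalization.
    return layer_norm(x + ffn(x, p["w1"], p["b1"], p["w2"], p["b2"]))


def init_params(d_model, d_ff, seed=0):
    # Random weights for illustration only.
    rng = np.random.default_rng(seed)
    s = 1.0 / np.sqrt(d_model)
    p = {k: rng.normal(0.0, s, (d_model, d_model)) for k in ("wq", "wk", "wv", "wo")}
    p["w1"] = rng.normal(0.0, s, (d_model, d_ff))
    p["b1"] = np.zeros(d_ff)
    p["w2"] = rng.normal(0.0, 1.0 / np.sqrt(d_ff), (d_ff, d_model))
    p["b2"] = np.zeros(d_model)
    return p


# 12 tokens with 16 features each.
x = np.random.default_rng(1).normal(size=(12, 16))
params = init_params(d_model=16, d_ff=64)
y_full = full_block(x, params)
y_ram = pruned_block(x, params)
n_attn_params = sum(params[k].size for k in ("wq", "wk", "wv", "wo"))  # 4 * 16 * 16 = 1024
```

Because `pruned_block` never touches `wq`, `wk`, `wv`, or `wo`, those weights and the n × n score computation drop out of the forward pass entirely. Whether accuracy survives that removal is the empirical question the paper studies; this sketch only shows the structural change.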
