Approximating Matrix Functions with Deep Neural Networks and Transformers
Rahul Padmanabhan, Simone Brugiapaglia
TL;DR
The paper studies how well neural networks can learn matrix functions such as $e^A$ and $\operatorname{sign}(A)$. On the theoretical side, it shows that a ReLU DNN can approximate $e^A$ over $[-M,M]^{n\times n}$ with width exponential in $nM$ and depth roughly linear in $nM$. On the practical side, it demonstrates that a transformer encoder-decoder using numerical encodings can achieve high accuracy for certain matrix functions on small matrices ($3\times3$ to $5\times5$). The numerical results depend strongly on the encoding scheme: FP15 performs best for the sign function and B1999 for the exponential, while sine and cosine remain challenging. Overall, the work highlights both the potential and the limitations of transformer-based surrogates for matrix-function computations in scientific computing, and it points to encoding design as a crucial lever for performance.
Abstract
Transformers have revolutionized natural language processing, but their use for numerical computation has received less attention. We study the approximation of matrix functions, which extend scalar functions to matrices, using neural networks including transformers. We focus on functions mapping square matrices to square matrices of the same dimension. Such matrix functions appear throughout scientific computing, e.g., the matrix exponential in continuous-time Markov chains and the matrix sign function in stability analysis of dynamical systems. In this paper, we make two contributions. First, we prove bounds on the width and depth of ReLU networks needed to approximate the matrix exponential to arbitrary precision. Second, we show experimentally that a transformer encoder-decoder with suitable numerical encodings can approximate certain matrix functions to within a relative error of 5% with high probability. Our study reveals that the encoding scheme strongly affects performance, with different schemes working better for different functions.
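To make the two matrix functions named in the abstract concrete, the following is a minimal NumPy sketch of classical reference computations that a learned surrogate could be compared against. It is not the paper's method: the function names (`expm_taylor`, `signm_newton`, `within_tolerance`), the scaling-and-squaring Taylor scheme, the Newton iteration for the sign function, and the use of the induced 1-norm in the 5% relative-error check are all illustrative assumptions, since the abstract does not specify the reference solver or the exact error metric.

```python
import numpy as np


def expm_taylor(A, terms=20):
    """Matrix exponential via scaling and squaring with a truncated Taylor series.

    Illustrative reference only (assumed, not taken from the paper).
    """
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    norm = np.linalg.norm(A, 1)
    # Choose s so that the scaled matrix A / 2**s has 1-norm at most 0.5.
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    B = A / (2 ** s)
    result = np.eye(n)
    term = np.eye(n)
    for k in range(1, terms + 1):
        term = term @ B / k          # term now holds B**k / k!
        result = result + term
    for _ in range(s):
        result = result @ result     # undo the scaling: e^A = (e^{A/2^s})^{2^s}
    return result


def signm_newton(A, iters=50, tol=1e-12):
    """Matrix sign function via the Newton iteration X <- (X + X^{-1}) / 2.

    Assumes A has no eigenvalues on the imaginary axis (so A is invertible).
    """
    X = np.asarray(A, dtype=float)
    for _ in range(iters):
        X_next = 0.5 * (X + np.linalg.inv(X))
        if np.linalg.norm(X_next - X, 1) <= tol * np.linalg.norm(X_next, 1):
            return X_next
        X = X_next
    return X


def within_tolerance(pred, ref, tol=0.05):
    """Check ||pred - ref|| <= tol * ||ref|| in the 1-norm (assumed metric)."""
    pred = np.asarray(pred, dtype=float)
    ref = np.asarray(ref, dtype=float)
    return bool(np.linalg.norm(pred - ref, 1) <= tol * np.linalg.norm(ref, 1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = rng.uniform(-1.0, 1.0, size=(3, 3))   # small matrix, as in the study
    E = expm_taylor(A)
    # A hypothetical "prediction" perturbed by 1% passes the 5% check.
    print(within_tolerance(1.01 * E, E))
```

Under these assumptions, a model output would count as correct when `within_tolerance(prediction, reference)` holds, which mirrors the "relative error of 5%" criterion only in spirit; the norm used in the paper may differ.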
