
Towards Learning High-Precision Least Squares Algorithms with Sequence Models

Jerry Liu, Jessica Grogan, Owen Dugan, Ashish Rao, Simran Arora, Atri Rudra, Christopher Ré

TL;DR

This work asks whether sequence models can learn numerical algorithms for least squares (LS) with machine-precision accuracy. It exposes expressivity and optimization bottlenecks in standard Transformers, which struggle both to reach near floating-point precision and to generalize numerically beyond their training distributions. Using polynomial architectures such as BaseConv together with a high-precision training recipe (an adaptive learning rate driven by a gradient-signal metric, plus an exponential moving average, or EMA, over updates), the authors train models to perform high-precision gradient-descent (GD) iterates, reaching MSEs as low as ~$10^{-13}$ for single iterates and ~$10^{-10}$ for multiple iterates, with significantly better out-of-distribution generalization than Transformers. End-to-end learning of the full GD trajectory remains challenging, but the results mark a substantial step toward learning numerical algorithms from data and highlight the tradeoff between architectural expressivity and optimization dynamics in scientific ML tasks. The findings suggest practical potential for precise numerics in scientific modeling, with BaseConv offering a scalable path to numerically robust algorithm learning for LS and related problems, and they point to future work on deeper optimization pipelines and broader PDE/ODE settings.

Abstract

This paper investigates whether sequence models can learn to perform numerical algorithms, e.g., gradient descent, on the fundamental problem of least squares. Our goal is to inherit two properties of standard algorithms from numerical analysis: (1) machine precision, i.e., we want to obtain solutions that are accurate to near floating-point error, and (2) numerical generality, i.e., we want them to apply broadly across problem instances. We find that prior approaches using Transformers fail to meet these criteria, and identify limitations present in existing architectures and training procedures. First, we show that softmax Transformers struggle to perform high-precision multiplications, which prevents them from precisely learning numerical algorithms. Second, we identify an alternative class of architectures, composed entirely of polynomials, that can efficiently represent high-precision gradient descent iterates. Finally, we investigate precision bottlenecks during training and address them via a high-precision training recipe that reduces stochastic gradient noise. Our recipe enables us to train two polynomial architectures, gated convolutions and linear attention, to perform gradient descent iterates on least squares problems. For the first time, we demonstrate the ability to train to near machine precision. Applied iteratively, our models obtain 100,000x lower MSE than standard Transformers trained end-to-end, and they incur a 10,000x smaller generalization gap on out-of-distribution problems. We make progress towards end-to-end learning of numerical algorithms for least squares.
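For reference, the gradient descent iterate that the models are trained to emulate can be written out directly. The sketch below assumes the objective $\frac{1}{2}\|Ax - b\|^2$, a fixed step size eta below $2/\lambda_{\max}(A^\top A)$, and a zero initialization; these choices and the helper names (matvec, gd_step, gd, mse) are illustrative assumptions, not the paper's exact experimental setup.

```python
# Minimal sketch (assumed setup): gradient descent on the least squares
# objective f(x) = 0.5 * ||A x - b||^2, whose gradient is A^T (A x - b).
# One GD iterate maps x_k to x_{k+1} = x_k - eta * A^T (A x_k - b).
# Pure Python floats are IEEE float64, so no third-party packages are needed.


def matvec(A, x):
    """Return A @ x for a dense matrix stored as a list of rows."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]


def matvec_t(A, r):
    """Return A^T @ r without materializing the transpose."""
    return [sum(A[i][j] * r[i] for i in range(len(A))) for j in range(len(A[0]))]


def gd_step(A, b, x, eta):
    """One gradient descent iterate on 0.5 * ||A x - b||^2 with step size eta."""
    residual = [ax_i - b_i for ax_i, b_i in zip(matvec(A, x), b)]
    grad = matvec_t(A, residual)
    return [x_j - eta * g_j for x_j, g_j in zip(x, grad)]


def gd(A, b, eta, num_steps):
    """Apply gd_step num_steps times, starting from the zero vector."""
    x = [0.0] * len(A[0])
    for _ in range(num_steps):
        x = gd_step(A, b, x, eta)
    return x


def mse(u, v):
    """Mean squared error between two vectors of equal length."""
    return sum((u_i - v_i) ** 2 for u_i, v_i in zip(u, v)) / len(u)


# Usage: a consistent 3x2 system, so x_true is the exact least squares solution.
# Here A^T A has largest eigenvalue about 5.3, so eta = 0.3 < 2 / 5.3 converges.
A = [[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
x_true = [1.0, -2.0]
b = matvec(A, x_true)
print(mse(gd(A, b, eta=0.3, num_steps=100), x_true))  # near float64 round-off
```

Because this iterate is the same fixed function of $(A, b)$ for every input, it applies unchanged to target vectors far outside any training range (for example, $b$ scaled up by 1000). That input-independence is the "numerical generality" the abstract asks learned models to inherit.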

Paper Structure

This paper contains 93 sections, 39 theorems, 292 equations, 18 figures, 7 tables, and 4 algorithms.

Key Result

Theorem 3.1

One-layer single-headed (causal) softmax attention cannot exactly represent $\textsc{Square}$ and $\textsc{Multiply}$ for all possible inputs.
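As an intuition for Theorem 3.1 (not its proof), consider scalar tokens. With a single input token, one softmax head puts weight 1 on that token, so its output is just the affine value projection $w_v x + b_v$, which cannot equal $x^2$ for all $x$; more generally, every output is a convex combination of the value vectors. A gated, polynomial layer instead multiplies two projections elementwise, so it can produce $x \cdot y$ or $x^2$ with a single floating-point product. The sketch assumes scalar tokens, hand-picked weights, and a simplified BaseConv-style layer of the form (linear projection) $\odot$ (causal convolution + bias); the function names are hypothetical and this is not the paper's construction.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)  # subtract the max so exp() cannot overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def causal_softmax_attention(u, wq, wk, wv, bv=0.0):
    """Single-head causal softmax attention on a sequence of scalars.

    The output at position t is a convex combination of values v_0..v_t."""
    q = [wq * x for x in u]
    k = [wk * x for x in u]
    v = [wv * x + bv for x in u]
    out = []
    for t in range(len(u)):
        weights = softmax([q[t] * k[s] for s in range(t + 1)])
        out.append(sum(w * v[s] for w, s in zip(weights, range(t + 1))))
    return out


def base_conv(u, w, b1, h, b2):
    """Simplified scalar gated convolution (assumed BaseConv-style form):
    out[t] = (w * u[t] + b1) * ((h * u)[t] + b2),
    where (h * u)[t] = sum_{s <= t} h[t - s] * u[s] is a causal convolution."""
    out = []
    for t in range(len(u)):
        conv = sum(h[t - s] * u[s] for s in range(t + 1) if t - s < len(h))
        out.append((w * u[t] + b1) * (conv + b2))
    return out


# Multiply: h = [0, 1] shifts the sequence by one, so position 1 of [x, y]
# computes y * x with one floating-point product.
print(base_conv([3.0, 7.0], w=1.0, b1=0.0, h=[0.0, 1.0], b2=0.0))  # [0.0, 21.0]
# Square: h = [1] gates the token with itself.
print(base_conv([1.5], w=1.0, b1=0.0, h=[1.0], b2=0.0))  # [2.25]
# One softmax head on one token returns wv * x + bv, which is affine in x.
print(causal_softmax_attention([5.0], wq=0.3, wk=0.7, wv=2.0, bv=1.0))  # [11.0]
```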

Figures (18)

  • Figure 1: Prior work focuses on statistical least squares: Transformers approximate Bayes-optimal estimators (left, adapted from Garg et al., 2022). In this work, we focus on numerical least squares: Transformers struggle to obtain precise solutions (inset). Using a high-precision training recipe, we train two polynomial architectures, BaseConv and linear attention, to perform high-precision gradient descent iterates on least squares (right): applied iteratively, they reach $\approx 10^{-13}$ MSE.
  • Figure 2: Transformers generalize poorly to out-of-distribution regression targets. In contrast, using our training recipe, we train a BaseConv model to perform high-precision GD iterates. Applied iteratively, our BaseConv model incurs $10,000 \times$ less generalization error on out-of-distribution target vectors than the Transformer.
  • Figure 3: Precision vs. Transformer depth, with and without LayerNorms (LN), on synthetic tasks. While shallow Transformers are able to learn the Read and Linear tasks to high precision ($< 10^{-8}$ with 2-layer models), precision on the Multiply task scales poorly with depth (only $10^{-6}$ with 8-layer models).
  • Figure 4: BaseConv can express high-precision gradient descent: our implementation of the weight construction reaches $10^{-13}$ MSE in practice.
  • Figure 5: Gradient metric is predictive of precision saturation (left). We propose a simple adaptive LR scheduler that alleviates precision saturation (middle). Adaptive LR effectively boosts gradient signal during training (right).
  • ...and 13 more figures
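The training recipe is only named here, so the sketch below illustrates one ingredient under a loud assumption: it reads "EMA over updates" as an exponential moving average of the parameter iterates, and uses a toy quadratic with artificially noisy gradients as a stand-in for stochastic training. The function names (ema_update, noisy_sgd_with_ema) and all hyperparameters are hypothetical, and the gradient-signal-based adaptive learning rate of Figure 5 is not modeled because its rule is not given here.

```python
import random


def ema_update(ema_params, params, decay):
    """Exponential moving average: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


def noisy_sgd_with_ema(grad_fn, x0, lr, decay, num_steps, noise_std, seed=0):
    """Run SGD with additive Gaussian gradient noise (a stand-in for minibatch
    noise) and track an EMA of the parameter iterates alongside it.

    Returns (last_iterate, ema_iterate)."""
    rng = random.Random(seed)  # seeded so runs are reproducible
    x = list(x0)
    ema = list(x0)
    for _ in range(num_steps):
        noisy_grad = [g + rng.gauss(0.0, noise_std) for g in grad_fn(x)]
        x = [x_j - lr * g_j for x_j, g_j in zip(x, noisy_grad)]
        ema = ema_update(ema, x, decay)
    return x, ema
```

For example, on the toy quadratic $\frac{1}{2}\|x\|^2$ (gradient $x$, optimum $0$) in 200 dimensions with lr = 0.1, decay = 0.99, and unit gradient noise, the EMA iterate's mean squared distance from the optimum should be roughly an order of magnitude below the last iterate's: a fixed step size leaves the last iterate bouncing at a noise floor, while averaging cancels much of that noise. This is the kind of stochastic-gradient-noise floor the recipe targets, though the paper's exact mechanism may differ.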

Theorems & Definitions (88)

  • Theorem 3.1: Informal statement of Theorem (thm:softmax_attn_square_app) and Corollary (cor:softmax-attn-cant-multiply)
  • Theorem 4.1: Informal statement of Theorems (thm: exactly-multiply) and (thm: exactly-linear)
  • Theorem 4.2: Informal statement of Theorem (prop:BC-approx-univar-func)
  • Definition D.1
  • Definition D.2
  • Definition D.3
  • Definition D.4
  • Definition D.5
  • Definition D.6
  • Definition D.7
  • ...and 78 more