
Optimizing Automatic Differentiation with Deep Reinforcement Learning

Jamie Lohoff, Emre Neftci

TL;DR

This work addresses efficient Jacobian computation in automatic differentiation by reinterpreting Jacobian accumulation as cross-country elimination on a computational graph. It introduces VertexGame, a single-player reinforcement learning framework, and an AlphaZero-style agent that discovers new elimination orders tailored to each problem; these orders compute Jacobians exactly while reducing the number of multiplications. The approach yields up to 33% fewer multiplications on the tested tasks and translates these gains into real runtime improvements via Graphax, a Python-based, JAX-backed AD interpreter. Collectively, the results show that learned elimination orders can outperform traditional heuristics and standard AD modes on a range of problems, with Graphax enabling practical deployment and further speedups on modern hardware.

Abstract

Computing Jacobians with automatic differentiation is ubiquitous in many scientific domains such as machine learning, computational fluid dynamics, robotics and finance. Even small savings in the number of computations or memory usage in Jacobian computations can translate into large savings in energy consumption and runtime. While many methods allow for such savings, they generally trade exactness for computational efficiency by approximating the Jacobian. In this paper, we present a novel method that optimizes the number of multiplications required for Jacobian computation by leveraging deep reinforcement learning (RL) and a concept called cross-country elimination, while still computing the exact Jacobian. Cross-country elimination is a framework for automatic differentiation that phrases Jacobian accumulation as the ordered elimination of all vertices on the computational graph, where every elimination incurs a certain computational cost. We formulate the search for the optimal elimination order, i.e. the one that minimizes the number of multiplications, as a single-player game played by an RL agent. We demonstrate that this method achieves up to 33% improvements over state-of-the-art methods on several relevant tasks taken from diverse domains. Furthermore, we show that these theoretical gains translate into actual runtime improvements by providing a cross-country elimination interpreter in JAX that can efficiently execute the obtained elimination orders.
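To make the cost of an elimination order concrete, here is a minimal sketch of vertex elimination on a purely scalar computational graph. It assumes the textbook cost model in which eliminating a vertex costs one multiplication per (predecessor, successor) pair; the paper's cost for vector-valued vertices also depends on Jacobian shapes and sparsity, which this sketch ignores. The graph `EDGES` (two inputs feeding a chain a -> b -> c that fans out to two outputs) and all function names are hypothetical, and the exhaustive search over orders only stands in for the RL agent: it is feasible for tiny graphs alone, since $n$ intermediate vertices admit $n!$ orders.

```python
from itertools import permutations


def build_graph(edges):
    """Return predecessor and successor sets for every vertex of a DAG."""
    preds, succs = {}, {}
    for u, w in edges:
        for x in (u, w):
            preds.setdefault(x, set())
            succs.setdefault(x, set())
        succs[u].add(w)
        preds[w].add(u)
    return preds, succs


def eliminate(preds, succs, v):
    """Eliminate vertex v in place and return its multiplication count.

    Every predecessor p of v is connected directly to every successor s of v.
    If an edge p -> s already exists, the new term is merged into it, which
    costs an addition rather than an extra multiplication.
    """
    ps, ss = preds.pop(v), succs.pop(v)
    for p in ps:
        succs[p].discard(v)
        succs[p].update(ss)
    for s in ss:
        preds[s].discard(v)
        preds[s].update(ps)
    return len(ps) * len(ss)


def order_cost(edges, order):
    """Total multiplications for eliminating the intermediate vertices in `order`."""
    preds, succs = build_graph(edges)
    return sum(eliminate(preds, succs, v) for v in order)


# Hypothetical example graph: x1, x2 -> a -> b -> c -> y1, y2.
EDGES = [("x1", "a"), ("x2", "a"), ("a", "b"), ("b", "c"), ("c", "y1"), ("c", "y2")]
INTERMEDIATES = ["a", "b", "c"]  # listed in topological order

forward = order_cost(EDGES, INTERMEDIATES)        # forward-mode order
reverse = order_cost(EDGES, INTERMEDIATES[::-1])  # reverse-mode order
best = min(permutations(INTERMEDIATES), key=lambda o: order_cost(EDGES, o))
print(forward, reverse, best, order_cost(EDGES, best))
```

This prints `8 8 ('b', 'a', 'c') 7`: the mixed order that eliminates `b` first needs one multiplication fewer than both the forward and the reverse order. In a VertexGame-style episode, the negative of this count could serve as the return the agent maximizes; that reward choice is an assumption of the sketch, not a detail stated here.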

Paper Structure (30 sections, 36 equations, 12 figures, 7 tables)

Figures (12)

  • Figure 1: Summary of the AlphaGrad pipeline. We trained a neural network to produce new Automatic Differentiation (AD) algorithms using Deep RL that can be used in JAX. The resulting algorithms significantly outperform the current state of the art.
  • Figure 2: Step-by-step description of cross-country elimination with the simple example function $f(x_1, x_2) = (\log \sin(x_1x_2), x_1x_2 - \sin(x_1x_2))^\top$. (a) Initial computational graph. (b) The partial derivatives are added to the edges of the computational graph. The intermediate variables $v_1$ and $v_2$ are defined through $v_1=x_1x_2$ and $v_2=\sin v_1$. (c) Elimination of vertex 2, associated with the $\sin$ operation. The dotted red lines represent the deleted edges. (d) Final bipartite graph after both intermediate vertices have been eliminated. All remaining edges contain entries of the Jacobian.
  • Figure 3: (a) Graphax implements sparse vertex elimination to benefit from the advantages of cross-country elimination. (b) Sketch of the three-dimensional adjacency tensor that represents the computational graph. The colored surfaces represent the five different values encoded in the third dimension. The red and blue surfaces together encode the shapes of the Jacobians, while the green surface encodes their sparsity. The vertical dotted slices represent the input connectivity of a single vertex. In this work, we compress the vertical slices and feed them as tokens into the transformer backbone, building a sequence that runs in the direction of the black arrow.
  • Figure 4: Runtime measurements over 1000 trials for the vectorized RoeFlux_3d and MLP tasks with different batch sizes, using the same setup as in the Graphax results table. The MLP network sizes were scaled up by a constant factor as the batch size grows; the exact scaling procedure is explained in the appendix on the comparison with JAX. Error bars are the 2.5th and 97.5th percentiles of the runtimes.
  • Figure 5: Runtime measurements over 100 trials for the scalar tasks. Error bars are the 2.5- and 97.5-percentiles of the runtimes.
  • ...and 7 more figures
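The elimination shown in Figure 2 can be reproduced numerically. The sketch below labels each edge of the Figure 2 graph with its local partial derivative, eliminates the intermediate vertices $v_1$ and $v_2$ in a chosen order, and reads the Jacobian off the remaining input-to-output edges. The vertex names, the dictionary representation and the convention of counting every edge product (including trivial products with $\pm 1$) as one multiplication are assumptions of this sketch, not Graphax's implementation.

```python
import math


def eliminate(edges, v):
    """Eliminate vertex v from an edge-labelled graph in place.

    `edges` maps (source, target) -> local partial derivative d target / d source.
    Returns the number of edge-label multiplications performed.
    """
    ins = {u: d for (u, w), d in edges.items() if w == v}
    outs = {w: d for (u, w), d in edges.items() if u == v}
    mults = 0
    for u, d_in in ins.items():
        for w, d_out in outs.items():
            # Chain rule along u -> v -> w, summed into any existing u -> w edge.
            edges[(u, w)] = edges.get((u, w), 0.0) + d_out * d_in
            mults += 1
    for u in ins:
        del edges[(u, v)]
    for w in outs:
        del edges[(v, w)]
    return mults


def figure2_jacobian(x1, x2, order):
    """Jacobian of f(x1, x2) = (log sin(x1 x2), x1 x2 - sin(x1 x2)) by vertex elimination."""
    v1 = x1 * x2
    v2 = math.sin(v1)
    edges = {
        ("x1", "v1"): x2, ("x2", "v1"): x1,     # v1 = x1 * x2
        ("v1", "v2"): math.cos(v1),             # v2 = sin v1
        ("v2", "y1"): 1.0 / v2,                 # y1 = log v2
        ("v1", "y2"): 1.0, ("v2", "y2"): -1.0,  # y2 = v1 - v2
    }
    mults = sum(eliminate(edges, v) for v in order)
    # Only bipartite input -> output edges remain; they hold the Jacobian entries.
    jac = [[edges[(x, y)] for x in ("x1", "x2")] for y in ("y1", "y2")]
    return jac, mults


jac_rev, cost_rev = figure2_jacobian(0.7, 1.3, order=("v2", "v1"))
jac_fwd, cost_fwd = figure2_jacobian(0.7, 1.3, order=("v1", "v2"))
print(cost_rev, cost_fwd)
```

This prints `6 8`: eliminating vertex 2 first, as in Figure 2, needs 6 multiplications, whereas the forward-mode order $v_1, v_2$ needs 8. Both orders produce the same exact Jacobian, for example $\partial y_1/\partial x_1 = x_2\cot(x_1x_2)$ and $\partial y_2/\partial x_1 = x_2(1-\cos(x_1x_2))$.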
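Figure 3 notes that Graphax performs sparse vertex elimination. The sketch below illustrates, under simplifying assumptions, why sparsity matters once vertices are vector-valued: each edge then carries a Jacobian matrix, and an elementwise operation such as $\sin$ has a diagonal Jacobian, so chaining it with a neighbouring edge needs far fewer multiplications than a dense matrix product. Storing a diagonal Jacobian as a flat list, and the helper names `edge_product` and `to_dense`, are hypothetical choices for illustration; the caption does not describe Graphax's internal sparsity representation.

```python
from typing import List, Tuple, Union

Dense = List[List[float]]  # full m x n Jacobian, stored row by row
Diag = List[float]         # diagonal of an n x n Jacobian (elementwise op)
Jac = Union[Dense, Diag]


def is_diag(jac: Jac) -> bool:
    """A flat list of numbers is treated as a diagonal Jacobian."""
    return not isinstance(jac[0], list)


def edge_product(outer: Jac, inner: Jac) -> Tuple[Jac, int]:
    """Return (outer @ inner, multiplication count) for two edge Jacobians."""
    if is_diag(outer) and is_diag(inner):  # diag @ diag: elementwise product
        return [a * b for a, b in zip(outer, inner)], len(outer)
    if is_diag(outer):                     # diag @ dense: scale rows of inner
        out = [[d * x for x in row] for d, row in zip(outer, inner)]
        return out, len(inner) * len(inner[0])
    if is_diag(inner):                     # dense @ diag: scale columns of outer
        out = [[x * d for x, d in zip(row, inner)] for row in outer]
        return out, len(outer) * len(inner)
    m, n, k = len(outer), len(inner), len(inner[0])
    out = [[sum(outer[i][t] * inner[t][j] for t in range(n)) for j in range(k)]
           for i in range(m)]
    return out, m * n * k                  # dense @ dense


def to_dense(diag: Diag) -> Dense:
    """Expand a diagonal Jacobian into a full square matrix."""
    n = len(diag)
    return [[diag[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


d = [0.5, 2.0, -1.0]                      # hypothetical diagonal Jacobian values
A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # hypothetical dense 3 x 2 Jacobian
sparse, c_sparse = edge_product(d, A)
dense, c_dense = edge_product(to_dense(d), A)
print(c_sparse, c_dense)
```

Both calls return the same matrix, but the diagonal-aware product uses 6 multiplications while the dense product uses $3 \cdot 3 \cdot 2 = 18$, so exploiting sparsity changes which elimination orders are cheapest.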
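Figures 4 and 5 summarize runtimes over repeated trials with error bars at the 2.5th and 97.5th percentiles. A minimal standard-library sketch of such a measurement follows; the function name `runtime_percentiles` is hypothetical, and reporting the median as the central value is an assumption, since the captions only specify the error bars. When timing JAX code, `fn` should call `block_until_ready()` on its result, because JAX dispatches work asynchronously and would otherwise be timed before the computation finishes.

```python
import statistics
import time
from typing import Callable, Tuple


def runtime_percentiles(fn: Callable[[], object], trials: int = 1000) -> Tuple[float, float, float]:
    """Time `fn` over `trials` runs; return (median, 2.5th, 97.5th percentile) in seconds."""
    times = []
    for _ in range(trials):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    # n=40 yields 39 cut points at 1/40, ..., 39/40, i.e. 2.5% steps.
    cuts = statistics.quantiles(times, n=40)
    return statistics.median(times), cuts[0], cuts[-1]


med, lo, hi = runtime_percentiles(lambda: sum(range(10_000)), trials=1000)
print(f"median {med:.2e} s, error bar [{lo:.2e}, {hi:.2e}] s")
```

A warm-up call before the timed loop (for example, to trigger JIT compilation) is usually advisable, so that the first, slower run does not skew the upper percentile.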

Theorems & Definitions (1)

  • Definition 1