Optimizing Tensor Computation Graphs with Equality Saturation and Monte Carlo Tree Search

Jakob Hartmann, Guoliang He, Eiko Yoneki

TL;DR

This paper presents a tensor graph rewriting approach that uses Monte Carlo tree search to build superior IRs by identifying the most promising rewrite rules. It also introduces a novel extraction algorithm that provides fast and accurate runtime estimates of tensor programs represented in an IR.
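The idea of searching over rewrite-rule choices can be illustrated with a generic Monte Carlo tree search that uses the standard UCT selection rule. This is a minimal sketch under stated assumptions, not the authors' algorithm: the rule names (`fuse_conv_relu`, `merge_matmul`, `split_concat`) and the scoring function `toy_reward` are hypothetical, and `toy_reward` merely stands in for "build the e-graph with this rule order under a node limit, extract a graph, and estimate its speedup". Each tree node holds a prefix of a rule ordering, and the search returns the best complete ordering it evaluated.

```python
import math
import random

# Hypothetical rewrite-rule names; a real optimizer would use tensor rewrite rules.
RULES = ("fuse_conv_relu", "merge_matmul", "split_concat")


def toy_reward(order):
    """Hypothetical stand-in for: build the e-graph applying rules in this order
    under a node limit, extract the best graph, return its estimated speedup."""
    weights = {"fuse_conv_relu": 0.2, "merge_matmul": 0.5, "split_concat": 0.3}
    # Made-up scoring: rules applied earlier contribute more.
    return sum(weights[r] / (i + 1) for i, r in enumerate(order))


class Node:
    def __init__(self, order, parent=None):
        self.order = order  # prefix of the rule ordering chosen so far
        self.parent = parent
        self.children = {}  # rule name -> child Node
        self.untried = [r for r in RULES if r not in order]
        self.visits = 0
        self.total = 0.0  # sum of rollout rewards through this node

    def uct_child(self, c=1.4):
        # UCT: mean reward plus an exploration bonus for rarely visited children.
        return max(
            self.children.values(),
            key=lambda ch: ch.total / ch.visits
            + c * math.sqrt(math.log(self.visits) / ch.visits),
        )


def mcts(iterations=500, seed=0):
    rng = random.Random(seed)
    root = Node(())
    best_order, best_reward = None, float("-inf")
    for _ in range(iterations):
        node = root
        # 1. Selection: descend while the node is fully expanded.
        while not node.untried and node.children:
            node = node.uct_child()
        # 2. Expansion: add one untried rule as a new child.
        if node.untried:
            rule = node.untried.pop(rng.randrange(len(node.untried)))
            child = Node(node.order + (rule,), parent=node)
            node.children[rule] = child
            node = child
        # 3. Rollout: complete the ordering at random and score it.
        rest = [r for r in RULES if r not in node.order]
        rng.shuffle(rest)
        order = node.order + tuple(rest)
        reward = toy_reward(order)
        if reward > best_reward:
            best_order, best_reward = order, reward
        # 4. Backpropagation: update statistics on the path back to the root.
        while node is not None:
            node.visits += 1
            node.total += reward
            node = node.parent
    return best_order, best_reward
```

With this made-up scoring, `mcts()` returns the ordering `('merge_matmul', 'split_concat', 'fuse_conv_relu')`, which is the one `toy_reward` scores highest. Returning the best evaluated ordering, rather than the most-visited root child, is a design choice for this sketch because every rollout here yields a complete, directly comparable ordering.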

Abstract

The real-world effectiveness of deep neural networks often depends on their latency, thereby necessitating optimization techniques that can reduce a model's inference time while preserving its performance. One popular approach is to sequentially rewrite the input computation graph into an equivalent but faster one by replacing individual subgraphs. This approach gives rise to the so-called phase-ordering problem, in which the application of one rewrite rule can eliminate the possibility of applying an even better one later on. Recent work has shown that equality saturation, a technique from compiler optimization, can mitigate this issue by first building an intermediate representation (IR) that efficiently stores multiple optimized versions of the input program before extracting the best solution in a second step. In practice, however, memory constraints prevent the IR from capturing all optimized versions and thus reintroduce the phase-ordering problem in the construction phase. In this paper, we present a tensor graph rewriting approach that uses Monte Carlo tree search to build superior IRs by identifying the most promising rewrite rules. We also introduce a novel extraction algorithm that can provide fast and accurate runtime estimates of tensor programs represented in an IR. Our approach improves the inference speedup of neural networks by up to 11% compared to existing methods.
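The phase-ordering problem in sequential rewriting can be reproduced on the expression $(a * 2) / 2$ used in Figure 1. This is a minimal sketch assuming expressions are nested Python tuples and that each rule is applied once, bottom-up, over the whole tree; the rules `mul2_to_shift` and `cancel_div` are hypothetical scalar illustrations, not the paper's tensor rewrite rules. Applying the strength reduction first destroys the `x * y` pattern that the cancellation rule needs.

```python
# Toy expressions are nested tuples (op, left, right); leaves are names or ints.

def mul2_to_shift(e):
    """Hypothetical strength-reduction rule: x * 2  ->  x << 1."""
    if isinstance(e, tuple) and e[0] == "*" and e[2] == 2:
        return ("<<", e[1], 1)
    return None


def cancel_div(e):
    """Hypothetical cancellation rule: (x * y) / y  ->  x."""
    if (isinstance(e, tuple) and e[0] == "/" and isinstance(e[1], tuple)
            and e[1][0] == "*" and e[1][2] == e[2]):
        return e[1][1]
    return None


def rewrite_everywhere(e, rule):
    """Apply `rule` once per node, bottom-up: rewrite children, then this node."""
    if isinstance(e, tuple):
        e = (e[0],) + tuple(rewrite_everywhere(c, rule) for c in e[1:])
    out = rule(e)
    return e if out is None else out


def rewrite_in_order(e, rules):
    # Sequential (destructive) rewriting: each rule sees only the previous result.
    for rule in rules:
        e = rewrite_everywhere(e, rule)
    return e


def ast_size(e):
    return 1 + sum(ast_size(c) for c in e[1:]) if isinstance(e, tuple) else 1


expr = ("/", ("*", "a", 2), 2)  # (a * 2) / 2
shift_first = rewrite_in_order(expr, [mul2_to_shift, cancel_div])
cancel_first = rewrite_in_order(expr, [cancel_div, mul2_to_shift])
print(shift_first, ast_size(shift_first))    # ('/', ('<<', 'a', 1), 2) 5
print(cancel_first, ast_size(cancel_first))  # a 1
```

Equality saturation sidesteps this particular conflict by keeping `a * 2` and `a << 1` in the same e-class, so the cancellation can still match; the memory limit on the IR is what brings the ordering question back during construction.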

Paper Structure (50 sections, 3 equations, 14 figures, 5 tables, 2 algorithms)

Figures (14)

  • Figure 1: Example e-graphs for expression $(a * 2) / 2$. E-classes are represented as rectangles, e-nodes are shown in circles.
  • Figure 2: Simple example of the phase-ordering problem during e-graph construction. The input expression is $(a * 2) / 2$, the node limit is set to 10, and the cost is the abstract syntax tree (AST) size of the extracted expression. The x-axis shows the sequence of applied rewrite rules, and the y-axis shows the resulting e-graph size together with the optimal cost at each point.
  • Figure 3: Overview of our tensor program optimizer using equality saturation and Monte Carlo tree search (MCTS).
  • Figure 4: Simple example of a neural network in which greedy extractors with existing cost functions overestimate the true graph runtime. The convolution operation marked in red is a common subexpression and thus counted twice, once by the add operation and once by the second convolution operation.
  • Figure 5: Speedup comparison on an NVIDIA A100 between different main and final extraction methods based on the original and optimized graph runtimes averaged across all runs and models. DCF = default cost function from egg, OCF = our cost function.
  • ...and 9 more figures
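The overcounting described in the Figure 4 caption can be reproduced in a few lines. This sketch assumes made-up per-operator runtimes (`OP_COST`) and a four-node graph shaped like the caption's example, where `conv1` feeds both the add and `conv2`. The function `tree_cost` mimics how a greedy bottom-up extractor sums child costs, so the shared convolution is charged once per consumer; `dag_cost` charges each distinct node once. Neither is the paper's own cost function (OCF).

```python
# Hypothetical per-operator runtimes in milliseconds (made up for illustration).
OP_COST = {"input": 0.0, "conv": 3.0, "add": 0.5}

# Graph shaped like Figure 4: conv1 is a common subexpression of add and conv2.
GRAPH = {
    "x": ("input", []),
    "conv1": ("conv", ["x"]),
    "conv2": ("conv", ["conv1"]),
    "out": ("add", ["conv1", "conv2"]),
}


def tree_cost(node):
    """Greedy-extractor style cost: own cost plus the cost of every child.
    A shared child is summed once for each parent that uses it."""
    op, children = GRAPH[node]
    return OP_COST[op] + sum(tree_cost(c) for c in children)


def dag_cost(root):
    """Cost of the graph as executed: each distinct node is computed once."""
    seen, stack = set(), [root]
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(GRAPH[n][1])
    return sum(OP_COST[GRAPH[n][0]] for n in seen)


print(tree_cost("out"))  # 9.5
print(dag_cost("out"))   # 6.5
```

The 3 ms gap between the two estimates is exactly the second, phantom evaluation of `conv1`.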