Pattern Matching in AI Compilers and its Formalization (Extended Version)

Joseph W. Cutler, Alex Collins, Bin Fan, Mahesh Ravishankar, Vinod Grover

TL;DR

The paper addresses the challenge of pattern-based optimization in AI compilers by introducing PyPM, a Python-embedded DSL for expressing recursive, nonlinear subgraph patterns and their rewrite rules. It formalizes PyPM via CorePyPM, providing both declarative and algorithmic semantics and a mechanized Coq proof of equivalence, thereby distilling a large, complex C++ matcher into a sound mathematical core. The authors demonstrate the approach with practical PyPM features (alternates, recursion, function variables, guards, and local variables) and show how PyPM enables effective hand-crafted optimizations and a just-in-time directed graph partitioning workflow within the DLCB backend. The work improves confidence in rewrite-based optimization for AI compilers, and provides a portable, expressive framework for advancing high-performance GPU kernel targeting across modern hardware.

Abstract

PyPM is a Python-based domain-specific language (DSL) for building rewrite-based optimization passes on machine learning computation graphs. Users define individual optimizations by writing (a) patterns that match subgraphs of a computation graph and (b) corresponding rules that replace a matched subgraph with an optimized kernel. PyPM is distinguished from the many other DSLs for defining rewriting passes by its complex and novel pattern language, which borrows concepts from logic programming. PyPM patterns can be recursive and nondeterministic, and can require checking domain-specific constraints such as the shapes of tensors. The PyPM implementation is thus similarly complicated, consisting of thousands of lines of C++ code. In this paper, we present our work on building PyPM, as well as on formalizing this complexity and distilling it to an understandable mathematical core. We have developed a formal core calculus expressing the main operations of the PyPM pattern language. We define both a declarative semantics (describing which patterns match which terms) and an algorithmic semantics (an idealized version of the PyPM pattern interpreter), and prove their equivalence. The development is fully mechanized in the Coq proof assistant.
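To make the abstract's vocabulary concrete, the following is a minimal toy matcher, not PyPM's actual API or the paper's CorePyPM calculus. Every name in it (`Term`, `Var`, `Op`, `Alt`, `Guard`, `match`) is a hypothetical stand-in chosen for illustration. It sketches three of the ideas mentioned above: nonlinear patterns (a repeated variable must bind to equal subterms), nondeterministic alternates (the matcher yields every successful substitution), and guards (a predicate checked on the substitution).

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterator, Tuple, Union


@dataclass(frozen=True)
class Term:
    """A computation-graph node, simplified to a tree: operator plus arguments."""
    op: str
    args: Tuple["Term", ...] = ()


@dataclass(frozen=True)
class Var:
    """Pattern variable; binds to any term."""
    name: str


@dataclass(frozen=True)
class Op:
    """Matches a term with the same operator and matching arguments."""
    op: str
    args: Tuple["Pattern", ...] = ()


@dataclass(frozen=True)
class Alt:
    """Alternate: matches if either sub-pattern matches."""
    left: "Pattern"
    right: "Pattern"


@dataclass(frozen=True)
class Guard:
    """Matches like `pat`, but only if `pred` accepts the substitution."""
    pat: "Pattern"
    pred: Callable[[Dict[str, Term]], bool]


Pattern = Union[Var, Op, Alt, Guard]
Subst = Dict[str, Term]


def match(p: Pattern, t: Term, theta: Subst) -> Iterator[Subst]:
    """Yield every extension of `theta` under which pattern `p` matches `t`."""
    if isinstance(p, Var):
        if p.name not in theta:
            yield {**theta, p.name: t}       # fresh variable: bind it
        elif theta[p.name] == t:
            yield theta                      # nonlinear use: must agree
    elif isinstance(p, Op):
        if p.op == t.op and len(p.args) == len(t.args):
            yield from _match_all(p.args, t.args, theta)
    elif isinstance(p, Alt):
        yield from match(p.left, t, theta)   # try both branches
        yield from match(p.right, t, theta)
    elif isinstance(p, Guard):
        for th in match(p.pat, t, theta):
            if p.pred(th):
                yield th


def _match_all(ps, ts, theta: Subst) -> Iterator[Subst]:
    """Match argument lists left to right, threading the substitution."""
    if not ps:
        yield theta
        return
    for th in match(ps[0], ts[0], theta):
        yield from _match_all(ps[1:], ts[1:], th)


x, y = Term("x"), Term("y")
square = Op("mul", (Var("a"), Var("a")))     # nonlinear: both args equal
add_or_mul = Alt(Op("add", (Var("a"), Var("b"))),
                 Op("mul", (Var("a"), Var("b"))))
only_x = Guard(Var("a"), lambda th: th["a"].op == "x")

squares = list(match(square, Term("mul", (x, x)), {}))
non_squares = list(match(square, Term("mul", (x, y)), {}))
alternates = list(match(add_or_mul, Term("mul", (x, y)), {}))
guarded = list(match(only_x, y, {}))
```

Here `squares` holds one substitution binding `a` to `x`, while `non_squares` is empty because the two occurrences of `a` disagree. A generator is used so that alternates can produce several candidate substitutions; this is one simple way to model nondeterminism and is not claimed to reflect how the C++ implementation works.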

Paper Structure

This paper contains 19 sections, 2 theorems, 9 equations, 18 figures.

Key Result

Theorem 1

If $p \, @ \, \theta \approx t$ and $\theta \subseteq \theta'$, then $p \, @ \, \theta' \approx t$.
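The statement can be illustrated on a toy declarative matcher. This sketch assumes that $p \, @ \, \theta \approx t$ reads as "pattern $p$ matches term $t$ under substitution $\theta$"; that reading, the tuple encoding of terms, and the function `holds` are assumptions made for illustration and are not taken from the paper's Coq development. It checks one instance of weakening and proves nothing in general.

```python
def holds(p, theta, t):
    """Toy declarative judgment: does pattern p match term t under theta?

    Terms are tuples ("op", arg1, ...), with leaves written ("x",).
    Patterns are either a variable name (str) or a tuple ("op", subpattern, ...).
    """
    if isinstance(p, str):                   # variable: must be bound to exactly t
        return p in theta and theta[p] == t
    op, *sub_patterns = p
    t_op, *t_args = t
    return (op == t_op
            and len(sub_patterns) == len(t_args)
            and all(holds(sp, theta, ta)
                    for sp, ta in zip(sub_patterns, t_args)))


pattern = ("mul", "a", "a")                  # nonlinear pattern a * a
term = ("mul", ("x",), ("x",))               # the term x * x

theta = {"a": ("x",)}
theta_prime = {**theta, "b": ("y",)}         # theta is a subset of theta_prime

is_subset = theta.items() <= theta_prime.items()
matches_small = holds(pattern, theta, term)
matches_large = holds(pattern, theta_prime, term)
```

In this instance `is_subset`, `matches_small`, and `matches_large` are all true: the extra binding for `b` is never consulted, so enlarging the substitution cannot invalidate the match.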

Figures (18)

  • Figure 1: cuBLAS Pattern Example
  • Figure 2: Alternate GELU Pattern
  • Figure 3: Recursive Unary Function Pattern
  • Figure 4: Recursive Pattern with Local Variables and Match Constraints
  • Figure 5: Grammar of Terms and Basic Patterns
  • ...and 13 more figures

Theorems & Definitions (3)

  • Theorem 1: Match Weakening
  • Theorem 2: Algorithmic Soundness
  • proof