GraphMend: Code Transformations for Fixing Graph Breaks in PyTorch 2

Savini Kashmira, Jayanaka Dantanarayana, Thamirawaran Sathiyalogeswaran, Yichao Yuan, Nishil Talati, Krisztian Flautner, Lingjia Tang, Jason Mars

TL;DR

GraphMend tackles FX graph breaks in PyTorch 2 caused by dynamic control flow and Python side effects by introducing two high-level AST/CFG-based transformations within the Jaseci-based Jac pipeline. By performing Dynamo entry-point analysis, graph-break tagging, and AST rewrites (Predicated Dynamic Control Flow and Graph Epilogue Deferred Side Effects) before bytecode generation, GraphMend keeps more computation inside a single FX graph. Evaluation on eight Hugging Face models shows that GraphMend eliminates most fixable breaks, reducing the break count to 0 in six models, and lowers cold-start and steady-state latency, with latency reductions of up to 75% and end-to-end throughput gains of up to 8% on NVIDIA RTX 3090 and A40 GPUs. Larger FX graphs also expose more opportunities for kernel fusion and better scheduling. The approach reduces the need for developer refactoring and complements PyTorch’s dynamic JIT.

Abstract

This paper presents GRAPHMEND, a high-level compiler technique that eliminates FX graph breaks in PyTorch 2 programs. Although PyTorch 2 introduced TorchDynamo and TorchInductor to enable just-in-time graph compilation, unresolved dynamic control flow and unsupported Python constructs often fragment models into multiple FX graphs. These fragments force frequent fallbacks to eager mode, introduce costly CPU-to-GPU synchronizations, and reduce optimization opportunities. GRAPHMEND addresses this limitation by analyzing and transforming source code before execution. Built on the Jaseci compilation framework, GRAPHMEND introduces two code transformations that remove graph breaks due to dynamic control flow and Python side effects. This design allows PyTorch's compilation pipeline to capture larger, uninterrupted FX graphs without requiring manual refactoring by developers. Evaluation across eight Hugging Face models shows that GRAPHMEND removes graph breaks due to dynamic control flow and Python side effects, reducing the break count to 0 in 6 models and reducing it from 5 to 2 in another model. On NVIDIA RTX 3090 and A40 GPUs, GRAPHMEND achieves up to 75% latency reductions and up to 8% higher end-to-end throughput. These results demonstrate that high-level code transformation is an effective complement to PyTorch's dynamic JIT compilation pipeline, substantially improving both usability and performance.
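
To make the dynamic-control-flow case concrete, the following is a minimal sketch of the kind of rewrite the abstract describes: a data-dependent Python branch replaced by a predicated tensor expression using `torch.where`. The function names and the arithmetic in each branch are hypothetical and chosen only for illustration; they are not taken from the paper's figures, and the rewrite is done by hand here rather than by GraphMend.

```python
import torch


def forward_branchy(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example: the Python `if` needs a concrete bool from
    # a tensor value, which is the kind of data-dependent branch that
    # can cause a graph break under torch.compile.
    if x.sum() > 0:
        return x * 2
    return x - 1


def forward_predicated(x: torch.Tensor) -> torch.Tensor:
    # Same result expressed as tensor ops only. The condition stays a
    # 0-dim bool tensor and broadcasts against both candidate results.
    cond = x.sum() > 0
    # Note: both branches are computed; torch.where selects between them.
    return torch.where(cond, x * 2, x - 1)
```

One design consequence worth noting: the predicated form evaluates both branches on every call, so it trades some redundant arithmetic for a branch-free function body. It is only a safe substitute when both branches are side-effect free and produce tensors of compatible shape and dtype.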

Paper Structure

This paper contains 33 sections, 12 figures, 3 tables, 1 algorithm.

Figures (12)

  • Figure 1: A PyTorch forward pass with data-dependent control flow. The conditional branch on x.sum cannot be captured, causing a graph break. (See also Figure 2 for the transformed version that eliminates the break.)
  • Figure 2: Eliminating the graph break by rewriting data-dependent control flow with torch.where, keeping the computation in a single FX graph. Together with Figure 1, this illustrates how dataflow rewriting replaces dynamic Python branching with predicated tensor execution.
  • Figure 3: Profiled traces of forward pass execution across CPU and GPU. (a) Forward pass execution trace of the code with graph breaks in Figure 1. (b) Forward pass execution trace of the equivalent code with graph breaks fixed in Figure 2.
  • Figure 4: (a) Existing graph breaks in a real model and (b) how they can be fixed to produce a single contiguous FX graph.
  • Figure 5: Compiled function with a print statement. Dynamo inserts a graph break at the print statement. (See also the figure labeled fig:stat_fix for the transformed version that avoids the break.)
  • ...and 7 more figures
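
The print-statement case described for Figure 5 can be sketched in the same spirit. This is a minimal, hand-written illustration of deferring a side effect to the end of a function so that the tensor operations stay contiguous. The function names and computations are hypothetical, and the sketch assumes that moving the `print` after the last tensor operation is what "graph epilogue" deferral amounts to; the paper's actual transformation may differ in detail.

```python
import torch


def forward_with_print(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example: a print in the middle of the tensor code.
    # Per the Figure 5 caption, Dynamo breaks the graph at the print.
    y = x * 2
    print("intermediate mean:", y.mean())
    return y + 1


def forward_deferred(x: torch.Tensor) -> torch.Tensor:
    # Assumed rewrite: keep the value to be logged as a tensor, finish
    # all tensor computation first, then run the side effect last.
    y = x * 2
    logged = y.mean()
    out = y + 1
    print("intermediate mean:", logged)  # side effect moved to the end
    return out
```

Both functions return the same tensor and print the same value. The deferral is only behavior-preserving when the printed value does not depend on, and is not changed by, the operations it is moved past.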