Learning Bug Context for PyTorch-to-JAX Translation with LLMs
Hung Phan, Son Le Vu, Ali Jannesari
TL;DR
This paper tackles PyTorch-to-JAX code translation, a task hindered by differences in core design and ecosystem maturity between the two libraries and by limited parallel corpora. It introduces T2J, a prompt-augmentation framework that builds a fixed-bug JAX dataset by having human developers repair translations drafted by a low-cost LLM, then uses this dataset to guide lightweight LLMs through augmented prompts. The authors also propose three evaluation metrics (T2J CodeTrans Score, T2J FixCost Score, and T2J Comparison Score) and report that T2J improves GPT-4o-mini by up to 10% on CodeBLEU, 50% on T2J FixCost Score, and 100% on T2J Comparison Score, with generated code that runs up to 2.5x faster than the baseline. Collectively, the work shows that domain-specific prompting and human-in-the-loop fixes can substantially improve cross-library translation, enabling more reliable PyTorch-to-JAX migrations and faster generated code.
Abstract
Despite recent progress of large language models (LLMs) on code translation among mainstream languages, translating PyTorch to JAX remains nontrivial. The two libraries, though both embedded in Python, differ in core design, execution semantics, and ecosystem maturity; JAX is newer and comparatively underrepresented in public code, and parallel PyTorch-JAX corpora are limited. Weaknesses in existing evaluation further complicate cross-framework benchmarking. We present T2J, a prompt-augmentation framework that strengthens LLM-based PyTorch-to-JAX translation. Our pipeline (i) assembles two PyTorch sources, the problem-solving set from TorchLeet (Aroori & Chien, 2025) and a GitHub-derived set from CodeParrot (Wolf et al., 2022), and uses GPT-4o-mini to produce initial JAX drafts; (ii) engages two professional developers to iteratively repair those drafts until they reach functional equivalence, yielding a curated fixed-bug dataset of common errors and patches; and (iii) constructs augmented prompts that inject structured guidance from these fixes to steer lightweight LLMs (e.g., GPT-4o-mini). We also introduce three metrics tailored to PyTorch-to-JAX translation: T2J CodeTrans Score, T2J FixCost Score (an LLM-based estimate of bug-fix effort), and T2J Comparison Score (LLM-as-judge). Empirically, T2J raises GPT-4o-mini performance by up to 10% on CodeBLEU, 50% on T2J FixCost Score, 1.33 points on T2J CodeTrans Score (0-4 scale), and 100% on T2J Comparison Score; moreover, the generated code runs up to 2.5x faster than the baseline.
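
Step (iii) of the pipeline, injecting guidance from past fixes into the translation prompt, can be illustrated with a minimal sketch. This is not the authors' implementation: the `BugFix` record, the `build_augmented_prompt` helper, the prompt wording, and the two example entries are all hypothetical, chosen only to show the general shape of prompt augmentation from a fixed-bug dataset.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BugFix:
    """One entry of a hypothetical fixed-bug dataset (not the paper's schema)."""
    error: str      # short description of the observed bug
    buggy_jax: str  # snippet from the LLM's initial JAX draft
    fixed_jax: str  # developer-repaired snippet


def build_augmented_prompt(pytorch_code: str, fixes: list[BugFix], max_fixes: int = 3) -> str:
    """Prepend guidance derived from past fixes to a plain translation request."""
    lines = ["Translate the following PyTorch code to JAX.", ""]
    selected = fixes[:max_fixes]  # assumption: simply take the first few entries
    if selected:
        lines.append("Avoid these errors seen in earlier translations:")
        for i, fix in enumerate(selected, start=1):
            lines.append(f"{i}. {fix.error}")
            lines.append(f"   Buggy: {fix.buggy_jax}")
            lines.append(f"   Fixed: {fix.fixed_jax}")
        lines.append("")
    lines.append("PyTorch code:")
    lines.append(pytorch_code)
    return "\n".join(lines)


# Illustrative entries only; they reflect well-known JAX pitfalls,
# not records taken from the T2J dataset.
fixes = [
    BugFix(
        error="In-place assignment on a JAX array (arrays are immutable).",
        buggy_jax="x[0] = 1.0",
        fixed_jax="x = x.at[0].set(1.0)",
    ),
    BugFix(
        error="Random sampling without an explicit PRNG key.",
        buggy_jax="w = jax.random.normal(shape=(3,))",
        fixed_jax="w = jax.random.normal(jax.random.PRNGKey(0), (3,))",
    ),
]

prompt = build_augmented_prompt("x = torch.zeros(3); x[0] = 1.0", fixes)
print(prompt)
```

The resulting string would be sent to the translation model in place of a bare "translate this" request. How fixes are selected and phrased in practice is a design choice this sketch does not attempt to reproduce.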
