Automatic Differentiation for ML-family languages: correctness via logical relations
Fernando Lucatelli Nunes, Matthijs Vákár
TL;DR
This work presents a principled, denotational approach to proving the correctness of automatic differentiation (AD) for ML-family languages that feature term and type recursion, partial differentiable primitives, and higher-order constructs. The authors introduce a dual-numbers AD macro $\mathcal{D}$ and a semantic framework based on $\omega$-Cpo categories, enriched logical relations (LR), and a subscone/scone infrastructure to reason about differentiation through recursion and polymorphism. They establish the correctness of forward- and reverse-mode AD, show that AD can be extended to recursive types and arrays, and develop almost-everywhere differentiability results by incorporating piecewise analytic partiality via a lifting of the partiality monad to the logical relations setting. The resulting framework provides a modular, semantic justification for AD in expressive ML-family languages and gives future compiler implementations and optimizations (including CHAD-style enhancements) a solid theoretical foundation.
Abstract
We give a simple, direct and reusable logical relations technique for languages with term and type recursion and partially defined differentiable functions. We demonstrate it by working out the case of Automatic Differentiation (AD) correctness: namely, we present a correctness proof of a dual numbers style AD code transformation for realistic functional languages in the ML-family. We also show how this code transformation provides us with correct forward- and reverse-mode AD. The starting point is to interpret a functional programming language as a suitable freely generated categorical structure. In this setting, by the universal property of the syntactic categorical structure, the dual numbers AD code transformation and the basic $\omega$-cpo semantics arise as structure preserving functors. The proof follows, then, by a novel logical relations argument. The key to much of our contribution is a powerful monadic logical relations technique for term recursion and recursive types. It provides us with a semantic correctness proof based on a simple approach for denotational semantics, making use only of the very basic concrete model of $\omega$-cpos.
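To make the idea of a dual-numbers transformation concrete, here is a minimal forward-mode sketch in Python. It is an illustration only, not the paper's macro $\mathcal{D}$ (which is a source-to-source transformation on an ML-family language): the names `Dual`, `lift`, `dlog`, `power`, and `derivative` are all hypothetical, and operator overloading stands in for the code transformation. Each real number is paired with a tangent, each primitive propagates both components, `dlog` plays the role of a partially defined differentiable primitive, and `power` shows differentiation passing through term recursion. Reverse mode is not covered here.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A dual number: a primal value paired with its tangent (derivative)."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = lift(other)
        # Sum rule: tangents add.
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (uv)' = u v' + u' v.
        return Dual(
            self.primal * other.primal,
            self.primal * other.tangent + self.tangent * other.primal,
        )

    __rmul__ = __mul__


def lift(x):
    """Treat a plain number as a constant, i.e. a dual number with zero tangent."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def dlog(x):
    """Partial primitive (assumption for illustration): log is undefined for x <= 0."""
    x = lift(x)
    if x.primal <= 0:
        raise ValueError("log is undefined outside its domain")
    return Dual(math.log(x.primal), x.tangent / x.primal)


def power(x, n):
    """x**n by term recursion; the derivative is carried through each call."""
    return lift(1.0) if n == 0 else x * power(x, n - 1)


def derivative(f, x):
    """Seed the tangent with 1.0 and read off the tangent of the result."""
    return f(Dual(x, 1.0)).tangent
```

For instance, `derivative(lambda x: power(x, 3), 2.0)` evaluates the recursive definition on dual numbers and returns the derivative of $x^3$ at $2$, while `derivative(dlog, -1.0)` raises an error because the primitive is undefined there. A correctness theorem of the kind the abstract describes says, roughly, that the tangent computed this way agrees with the true derivative wherever the original program is defined and differentiable.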
