Automatic Differentiation for ML-family languages: correctness via logical relations

Fernando Lucatelli Nunes, Matthijs Vákár

TL;DR

This work presents a principled, denotational approach to proving the correctness of automatic differentiation (AD) for ML-family languages that feature term and type recursion, partially defined differentiable primitives, and higher-order constructs. The authors introduce a dual-numbers AD macro $\mathcal{D}$ and a semantic framework based on $\omega$-cpo-enriched categories, logical relations (LR), and a subscone/scone construction to reason about differentiation through recursion and polymorphism. They establish forward- and reverse-mode AD correctness, show that AD extends to recursive types and arrays, and prove almost-everywhere differentiability results by incorporating piecewise analytic partiality through a lifting of the partiality monad to the LR setting. The resulting framework gives a modular, semantic justification for AD in expressive ML-family languages and a theoretical foundation for future compiler implementations and optimizations, including CHAD-style extensions.
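
To make the idea of a dual-numbers macro concrete, here is a minimal sketch, assuming a hypothetical first-order expression language with only variables, constants, $+$, $\times$, $\sin$ and $\cos$. The names `Var`, `Const`, `D`, `evaluate` and the convention that the tangent of `x` is the variable `dx` are our own illustration, not the paper's syntax; the paper's $\mathcal{D}$ acts on a full ML-family language. As in the paper, where $\mathcal{D}$ arises as a structure-preserving functor, the sketch is defined by structural recursion with one clause per construct, sending each real-typed term to a (primal, tangent) pair of terms.

```python
import math
from dataclasses import dataclass
from typing import Dict, Tuple, Union

# Hypothetical tiny source language (illustration only).
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Add:
    left: "Expr"
    right: "Expr"

@dataclass(frozen=True)
class Mul:
    left: "Expr"
    right: "Expr"

@dataclass(frozen=True)
class Sin:
    arg: "Expr"

@dataclass(frozen=True)
class Cos:
    arg: "Expr"

Expr = Union[Var, Const, Add, Mul, Sin, Cos]

def D(e: Expr) -> Tuple[Expr, Expr]:
    """Dual-numbers transformation: each term becomes a (primal, tangent) pair of terms."""
    if isinstance(e, Var):
        return e, Var("d" + e.name)  # hypothetical naming convention for tangents
    if isinstance(e, Const):
        return e, Const(0.0)
    if isinstance(e, Add):
        (p1, t1), (p2, t2) = D(e.left), D(e.right)
        return Add(p1, p2), Add(t1, t2)
    if isinstance(e, Mul):
        (p1, t1), (p2, t2) = D(e.left), D(e.right)
        # Product rule: d(uv) = du*v + u*dv
        return Mul(p1, p2), Add(Mul(t1, p2), Mul(p1, t2))
    if isinstance(e, Sin):
        p, t = D(e.arg)
        return Sin(p), Mul(Cos(p), t)
    if isinstance(e, Cos):
        p, t = D(e.arg)
        return Cos(p), Mul(Const(-1.0), Mul(Sin(p), t))
    raise TypeError(f"unknown term: {e!r}")

def evaluate(e: Expr, env: Dict[str, float]) -> float:
    """Standard evaluator, used to run both original and transformed terms."""
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Const):
        return e.value
    if isinstance(e, Add):
        return evaluate(e.left, env) + evaluate(e.right, env)
    if isinstance(e, Mul):
        return evaluate(e.left, env) * evaluate(e.right, env)
    if isinstance(e, Sin):
        return math.sin(evaluate(e.arg, env))
    if isinstance(e, Cos):
        return math.cos(evaluate(e.arg, env))
    raise TypeError(f"unknown term: {e!r}")

# t = x * sin(x), whose derivative is sin(x) + x * cos(x).
t = Mul(Var("x"), Sin(Var("x")))
primal, tangent = D(t)
env = {"x": 2.0, "dx": 1.0}
print(evaluate(primal, env), evaluate(tangent, env))
```

Because this sketch duplicates subterms instead of sharing them, the tangent term can grow quickly; a practical transformation would introduce let-bindings to preserve sharing.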

Abstract

We give a simple, direct and reusable logical relations technique for languages with term and type recursion and partially defined differentiable functions. We demonstrate it by working out the case of Automatic Differentiation (AD) correctness: namely, we present a correctness proof of a dual numbers style AD code transformation for realistic functional languages in the ML-family. We also show how this code transformation provides us with correct forward- and reverse-mode AD. The starting point is to interpret a functional programming language as a suitable freely generated categorical structure. In this setting, by the universal property of the syntactic categorical structure, the dual numbers AD code transformation and the basic $ω$-cpo semantics arise as structure preserving functors. The proof follows, then, by a novel logical relations argument. The key to much of our contribution is a powerful monadic logical relations technique for term recursion and recursive types. It provides us with a semantic correctness proof based on a simple approach for denotational semantics, making use only of the very basic concrete model of $ω$-cpos.
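
The $\omega$-cpo semantics mentioned in the abstract interprets term recursion by least fixed points. As a hedged illustration (a standard textbook construction, not code from the paper), the sketch below builds the Kleene chain $\bot \sqsubseteq F(\bot) \sqsubseteq F(F(\bot)) \sqsubseteq \cdots$ in the $\omega$-cpo of partial functions on integers, using `None` to stand for "undefined". The least fixed point of $F$, here factorial, is the supremum of this chain. The names `F`, `bottom` and `approximant` are ours.

```python
from typing import Callable, Optional

# A partial function on integers: None stands for "undefined" (bottom).
Partial = Callable[[int], Optional[int]]

def bottom(n: int) -> Optional[int]:
    """Least element of the omega-cpo: the everywhere-undefined function."""
    return None

def F(f: Partial) -> Partial:
    """One unfolding of: fact n = if n == 0 then 1 else n * fact (n - 1)."""
    def g(n: int) -> Optional[int]:
        if n == 0:
            return 1
        r = f(n - 1)
        # Undefinedness propagates through multiplication.
        return None if r is None else n * r
    return g

def approximant(k: int) -> Partial:
    """The k-th element F^k(bottom) of the Kleene chain."""
    f: Partial = bottom
    for _ in range(k):
        f = F(f)
    return f

# Values of the first few approximants on inputs 0..5.
table = {k: [approximant(k)(n) for n in range(6)] for k in range(4)}
print(table)
```

The $k$-th approximant is defined exactly on $0 \le n < k$, and each approximant extends the previous one: this is the monotone $\omega$-chain whose supremum the semantics takes as the meaning of the recursive definition.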

Paper Structure (56 sections, 43 theorems, 175 equations, 16 figures)

Key Result

Theorem 1 (Forward AD Correctness)

For any program $x:\tau \vdash t:\sigma$ with $\tau=\mathbf{real}^{m}$ and $\sigma=\mathbf{real}^{l}$ (where we write $\mathbf{real}^n$ for the type $\mathbf{real}\boldsymbol{*}\cdots\boldsymbol{*}\mathbf{real}$ of $n$-tuples of reals), we have that $[\![t]\!]$ is differentiable on its … for any $(x_1,\ldots,x_m)$ in the domain of definition of $[\![t]\!]$ and any tangent v…
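
The statement above is truncated. The usual shape of a forward-mode correctness result, which we give as our paraphrase rather than the paper's exact wording, is that running the transformed program on a primal input $x$ paired with a tangent $v$ returns $[\![t]\!](x)$ paired with the Jacobian-vector product $J_{[\![t]\!]}(x)\,v$. The sketch below checks this numerically for one hypothetical total program $t:\mathbf{real}^2\to\mathbf{real}^2$, using operator-overloaded dual numbers as a stand-in for evaluating the transformed program. `Dual`, `dsin`, `t` and `forward_ad` are illustrative names.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple

@dataclass(frozen=True)
class Dual:
    """A dual number: a primal value paired with a single tangent."""
    primal: float
    tangent: float

    def __add__(self, other: "Dual") -> "Dual":
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other: "Dual") -> "Dual":
        # Product rule on the tangent component.
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

def dsin(a: Dual) -> Dual:
    return Dual(math.sin(a.primal), math.cos(a.primal) * a.tangent)

def t(x1: Dual, x2: Dual) -> Tuple[Dual, Dual]:
    # Hypothetical example program of type real^2 -> real^2.
    return (x1 * x2, dsin(x1) + x2 * x2)

def forward_ad(prog: Callable[..., Tuple[Dual, ...]],
               xs: Sequence[float], vs: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Run prog on (x, v) pairs; return (primal outputs, tangent outputs)."""
    outs = prog(*[Dual(x, v) for x, v in zip(xs, vs)])
    return [o.primal for o in outs], [o.tangent for o in outs]

xs, vs = (1.5, -0.5), (0.3, 2.0)
values, jvp = forward_ad(t, xs, vs)

# Hand-computed Jacobian J = [[x2, x1], [cos x1, 2*x2]] applied to vs.
x1, x2 = xs
v1, v2 = vs
expected_jvp = [x2 * v1 + x1 * v2, math.cos(x1) * v1 + 2 * x2 * v2]
print(values, jvp, expected_jvp)
```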

Figures (16)

  • Figure 5.1: Typing rules for a basic source language with real conditionals, where $\mathtt{R}\subseteq\mathbb{R}$ is a fixed set of real numbers containing $0$.
  • Figure 5.2: Typing rules for term recursion and iteration.
  • Figure 5.3: Basic $\beta\eta$-equational theory for our language. We write $\beta\eta$-equality as $\equiv$ to distinguish it from equality in let-bindings. We write $\overset{\# { x}_1,\ldots,{ x}_n}{\equiv}$ to indicate that the variables are fresh in the left hand side. In the top right rule, ${ x}$ may not be free in ${ r}$. Equations hold on pairs of computations of the same type.
  • Figure 5.4: Extra typing rules for the target language with iteration and recursion, where $\mathbb{N}^\ast := \mathbb{N}\setminus\{0\}$, $\mathbf{real}^1 := \mathbf{real}$ and $\mathbf{real}^{i+1} := \mathbf{real}^i \times \mathbf{real}$.
  • Figure 5.5: Assignment that gives the universal property of the source language.
  • ...and 11 more figures
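
Figure 5.1 involves real conditionals, and the TL;DR mentions almost-everywhere differentiability results. The following standard example from the AD literature (not taken from the paper, whose exact treatment of conditionals may differ) shows why correctness for programs with conditionals can only be expected almost everywhere: dual numbers branch on the primal value alone, so a program that denotes the identity but special-cases $0$ gets derivative $0$ at that single point. The names `relu` and `silly_id` are hypothetical.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    primal: float
    tangent: float

def relu(x: Dual) -> Dual:
    # Real conditional: the branch is chosen by the primal value alone.
    return x if x.primal > 0 else Dual(0.0, 0.0)

def silly_id(x: Dual) -> Dual:
    # Denotes the identity function on all of R, but takes a constant branch at 0.
    return Dual(0.0, 0.0) if x.primal == 0 else x

# Seed tangent 1.0 to read off the derivative AD reports at each point.
points = [-1.0, -1e-9, 0.0, 1e-9, 2.0]
relu_derivs = [relu(Dual(p, 1.0)).tangent for p in points]
id_derivs = [silly_id(Dual(p, 1.0)).tangent for p in points]
# The true derivative of the identity is 1 everywhere; AD disagrees only at 0.0.
mismatches = [p for p, d in zip(points, id_derivs) if d != 1.0]
print(relu_derivs, id_derivs, mismatches)
```

The disagreement is confined to the measure-zero set $\{0\}$; at every other input the dual-number tangent matches the true derivative.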

Theorems & Definitions (87)

  • Theorem 1: Forward AD Correctness (the main correctness theorem of the main text, with $k=1$)
  • Theorem 2: Logical relations for recursive types (a special case of the main logical-relations result for recursive types in the main text)
  • Theorem 3: Reverse AD Correctness (the main correctness theorem of the main text, with $k=\infty$)
  • Definition 1: CBV pair
  • Remark 4
  • Definition 2: Free Recursion and Iteration
  • Definition 3: CBV model
  • Definition 4: CBV $\boldsymbol{\omega}\mathbf{Cpo}$-pair
  • Lemma 1: Underlying CBV model
  • Proof
  • ...and 77 more
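
Theorems 1 and 3 are stated as the instances $k=1$ and $k=\infty$ of one main theorem; we do not reconstruct what $k=\infty$ denotes. As an assumption for illustration only, the sketch below takes $k$ to be the dimension of the tangent space carried by each dual number. With $k$-dimensional tangents, one forward pass propagates $k$ directional derivatives at once, and seeding the $m$ inputs with the standard basis of $\mathbb{R}^m$ (so $k=m$) yields the whole Jacobian. `DualK`, `jacobian` and the example program `t` are hypothetical names.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple

@dataclass(frozen=True)
class DualK:
    """Dual number whose tangent part lies in R^k (k directional derivatives)."""
    primal: float
    tangent: Tuple[float, ...]

    def __add__(self, other: "DualK") -> "DualK":
        return DualK(self.primal + other.primal,
                     tuple(a + b for a, b in zip(self.tangent, other.tangent)))

    def __mul__(self, other: "DualK") -> "DualK":
        # Product rule, applied componentwise to the k tangent directions.
        return DualK(self.primal * other.primal,
                     tuple(a * other.primal + self.primal * b
                           for a, b in zip(self.tangent, other.tangent)))

def dsin(a: DualK) -> DualK:
    c = math.cos(a.primal)
    return DualK(math.sin(a.primal), tuple(c * ta for ta in a.tangent))

def t(x1: DualK, x2: DualK) -> Tuple[DualK, DualK]:
    # Hypothetical example program of type real^2 -> real^2.
    return (x1 * x2, dsin(x1) + x2 * x2)

def jacobian(prog: Callable[..., Tuple[DualK, ...]],
             xs: Sequence[float]) -> List[List[float]]:
    m = len(xs)
    # Seed input i with the i-th standard basis vector of R^m, so k = m.
    seeds = [DualK(x, tuple(1.0 if j == i else 0.0 for j in range(m)))
             for i, x in enumerate(xs)]
    # Row r of the result is the gradient of output r.
    return [list(out.tangent) for out in prog(*seeds)]

J = jacobian(t, (1.5, -0.5))
print(J)
```

For a scalar-valued program ($l=1$) the single row is the full gradient, the quantity reverse mode is usually used to compute. This vectorized forward approach costs work proportional to $m$, which is why reverse mode is preferred when $m$ is large.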