Automatic Functional Differentiation in JAX
Min Lin
TL;DR
The paper addresses the lack of automatic differentiation tools for functionals by reusing JAX's AD machinery to perform automatic functional differentiation (AutoFD). It introduces a generalized-array representation of functions and a core set of primitive operators: $\hat{C}$ (compose), $\nabla$, $\hat{L}$ (linearize), $\hat{T}$ (linear transpose), and $\hat{I}$ (integrate). Each primitive is given a JVP (forward-mode) rule and a transpose rule, which together enable higher-order differentiation and yield functional gradients as callable Python functions. The work discusses completeness of the primitive set, numerical integration on grids, and practical considerations, and demonstrates applications in variational problems and density-functional theory. By supporting functional derivatives in the same syntax used for ordinary functions, AutoFD aims to simplify the development of function-space optimization methods and neural-operator frameworks. The approach provides a way to differentiate higher-order constructs such as integral functionals and neural operators without symbolic manipulation.
Abstract
We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators). By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions. We present a set of primitive operators that serve as foundational building blocks for constructing several key types of functionals. For every introduced primitive operator, we derive and implement both linearization and transposition rules, aligning with JAX's internal protocols for forward- and reverse-mode automatic differentiation. This enhancement allows for functional differentiation in the same syntax traditionally used for functions. The resulting functional gradients are themselves functions ready to be invoked in Python. We showcase this tool's efficacy and simplicity through applications where functional derivatives are indispensable. The source code of this work is released at https://github.com/sail-sg/autofd.
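To make the notion of a functional derivative concrete, the following is a minimal sketch in plain JAX. It does not use the AutoFD API. It assumes a hypothetical example functional $F[f] = \int_0^1 f(x)^2\,dx$, discretizes it on a uniform grid with trapezoidal weights, and recovers an approximation of $\delta F / \delta f(x) = 2 f(x)$ at the grid points with `jax.grad`. The names `F`, `w`, `f_vals`, and `func_grad` are illustrative. The result is an array of grid values, whereas the abstract describes functional gradients that are themselves callable functions.

```python
import jax
import jax.numpy as jnp

# Uniform grid on [0, 1] with trapezoidal quadrature weights (illustrative choice).
n = 101
x = jnp.linspace(0.0, 1.0, n)
dx = x[1] - x[0]
w = jnp.full((n,), dx).at[0].set(dx / 2).at[-1].set(dx / 2)


def F(f_vals):
    # Discretized functional: F[f] ~= sum_i w_i * f(x_i)^2
    return jnp.sum(w * f_vals**2)


# Sample the input function f(x) = sin(x) on the grid.
f_vals = jnp.sin(x)

# dF/df_i = 2 * w_i * f_i; dividing by the quadrature weight w_i
# approximates the functional derivative delta F / delta f at x_i, i.e. 2 f(x_i).
func_grad = jax.grad(F)(f_vals) / w
```

Dividing the ordinary gradient by the quadrature weights is what turns a derivative with respect to grid values into an approximation of the functional derivative. This grid-based picture is only an approximation under the stated assumptions, and it is not how the paper's primitives are implemented.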
