Compiling to linear neurons
Joey Velez-Ginorio, Nada Amin, Konrad Kording, Steve Zdancewic
TL;DR
This work tackles the challenge of directly programming neural networks by bridging discrete programming with differentiable learning through a linear, typed language called Cajal. It shows how Cajal programs compile to linear neurons, enabling discrete algorithms to be expressed in differentiable form, and provides a formal compiler-correctness framework along with a practical embedding in PyTorch. Through two experiments on conditional image transformation and classification, the authors demonstrate that directly programmed networks can learn faster and more data-efficiently than traditional indirectly programmed networks, with clear debugging advantages when linking to nonlinear components. The study highlights a promising interplay between learning and discrete programming structures, while outlining important limitations and avenues for future expansion, including richer type systems and recurrence for infinite data.
Abstract
We don't program neural networks directly. Instead, we rely on an indirect style where learning algorithms, like gradient descent, determine a neural network's function by learning from data. This indirect style is often a virtue; it empowers us to solve problems that were previously impossible. But it lacks discrete structure. We can't compile most algorithms into a neural network, even if these algorithms could help the network learn. This limitation occurs because discrete algorithms are not obviously differentiable, making them incompatible with the gradient-based learning algorithms that determine a neural network's function. To address this, we introduce Cajal: a typed, higher-order and linear programming language intended to be a minimal vehicle for exploring a direct style of programming neural networks. We prove Cajal programs compile to linear neurons, allowing discrete algorithms to be expressed in a differentiable form compatible with gradient-based learning. With our implementation of Cajal, we conduct several experiments where we link these linear neurons against other neural networks to determine part of their function prior to learning. Linking with these neurons allows networks to learn faster, with greater data-efficiency, and in a way that's easier to debug. A key lesson is that linear programming languages provide a path towards directly programming neural networks, enabling a rich interplay between learning and the discrete structures of ordinary programming.
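The central claim above is that discrete computations can be written as linear maps, which makes them differentiable. As a rough illustration of that idea only (this is not Cajal's compilation scheme, which the abstract does not spell out), the sketch below assumes booleans are encoded as one-hot basis vectors in R^2 and that two inputs are combined with a tensor (Kronecker) product. Under those assumptions, NOT, AND, and an if-then-else over a vector each become a plain matrix. All names here (TRUE, FALSE, matvec, kron, NOT, AND, if_then_else) are hypothetical and chosen for illustration.

```python
# Hypothetical sketch: discrete operations on one-hot booleans as matrices.
# Assumption (not from the paper): TRUE/FALSE are the standard basis of R^2.
TRUE = [1.0, 0.0]
FALSE = [0.0, 1.0]


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def kron(u, v):
    """Tensor (Kronecker) product of two vectors, flattened row-major."""
    return [ui * vj for ui in u for vj in v]


# NOT swaps the two basis vectors, so it is a permutation matrix.
NOT = [[0.0, 1.0],
       [1.0, 0.0]]

# AND acts on the 4-dim space kron(a, b), whose basis order is
# (T,T), (T,F), (F,T), (F,F).
AND = [[1.0, 0.0, 0.0, 0.0],   # TRUE component: only (T,T)
       [0.0, 1.0, 1.0, 1.0]]   # FALSE component: every other pair


def if_then_else(b, x, F, G):
    """'if b then F x else G x' as one linear map applied to kron(b, x).

    The block matrix [F | G] picks F x when b = TRUE and G x when b = FALSE.
    """
    block = [f_row + g_row for f_row, g_row in zip(F, G)]
    return matvec(block, kron(b, x))


if __name__ == "__main__":
    ID = [[1.0, 0.0], [0.0, 1.0]]
    print(matvec(NOT, TRUE))                          # [0.0, 1.0]
    print(matvec(AND, kron(TRUE, FALSE)))             # [0.0, 1.0]
    print(if_then_else(TRUE, [3.0, 5.0], ID, NOT))    # [3.0, 5.0]
    print(if_then_else(FALSE, [3.0, 5.0], ID, NOT))   # [5.0, 3.0]
    # A "soft" condition blends the two branches linearly.
    print(if_then_else([0.5, 0.5], [3.0, 5.0], ID, NOT))  # [4.0, 4.0]
```

Because the conditional is a single linear map on b ⊗ x, its output is linear in the condition b, so a gradient with respect to b is well defined. Note that each input is used exactly once; duplicating an input (x ↦ x ⊗ x) would not be linear, which is one intuition for why a linear type discipline fits this kind of compilation.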
