SoftJAX & SoftTorch: Empowering Automatic Differentiation Libraries with Informative Gradients

Anselm Paulus, A. René Geist, Vít Musil, Sebastian Hoffmann, Onur Beker, Georg Martius

TL;DR

This work introduces SoftJAX and SoftTorch, open-source, feature-complete libraries for soft differentiable programming, which provide a variety of soft functions as drop-in replacements for their hard JAX and PyTorch counterparts.

Abstract

Automatic differentiation (AD) frameworks such as JAX and PyTorch have enabled gradient-based optimization for a wide range of scientific fields. Yet many "hard" primitives in these libraries, such as thresholding, Boolean logic, discrete indexing, and sorting operations, yield zero or undefined gradients that are not useful for optimization. While numerous "soft" relaxations have been proposed that provide informative gradients, their implementations are fragmented across projects, making them difficult to combine and compare. This work introduces SoftJAX and SoftTorch, open-source, feature-complete libraries for soft differentiable programming. These libraries provide a variety of soft functions as drop-in replacements for their hard JAX and PyTorch counterparts. They include (i) elementwise operators such as clip or abs, (ii) utility methods for manipulating Booleans and indices via fuzzy logic, (iii) axiswise operators such as sort or rank, based on optimal transport or permutahedron projections, and (iv) full support for straight-through gradient estimation. Overall, SoftJAX and SoftTorch make the toolbox of soft relaxations easily accessible to differentiable programming, as demonstrated through benchmarking and a practical case study. Code is available at github.com/a-paulus/softjax and github.com/a-paulus/softtorch.
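
To make the zero-gradient problem concrete, the following minimal sketch compares a hard threshold with a sigmoid relaxation in plain JAX. It does not use the SoftJAX API: the names `hard_step` and `soft_step`, the choice of a sigmoid, and the default `tau` are assumptions made purely for illustration.

```python
import jax
import jax.numpy as jnp


def hard_step(x):
    # Hard threshold: piecewise constant, so the derivative is zero
    # wherever it is defined.
    return jnp.where(x > 0.0, 1.0, 0.0)


def soft_step(x, tau=1.0):
    # Sigmoid relaxation with softness `tau` (an illustrative choice,
    # not necessarily what SoftJAX uses).
    return jax.nn.sigmoid(x / tau)


x = jnp.asarray(0.5)
hard_grad = jax.grad(hard_step)(x)  # carries no information about x
soft_grad = jax.grad(soft_step)(x)  # strictly positive for finite x
```

Smaller values of `tau` bring the relaxation closer to the hard step, at the cost of gradients that concentrate around the threshold.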

Paper Structure (52 sections, 1 theorem, 84 equations, 20 figures, 2 tables)

Key Result

Theorem 1

Let $n \geq 2$, $1 < p \leq 2$, $\tau > 0$, and let $q = p/(p{-}1)$ be the conjugate exponent with $q$ a positive integer. Define $k \coloneqq q - 2$. Moreover, the bound $\mathcal{C}^k$ is sharp: $\Pi_\tau$ is not $\mathcal{C}^{k+1}$, and $\Gamma_\tau^\star$ is not $\mathcal{C}^{k+1}$. In particular, c0 ($p\!=\!2$) gives $\mathcal{C}^0$, c1 ($p\!=\!3/2$) gives $\mathcal{C}^1$, c2 ($p\!=\!4/3$) g
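
As a quick check of the exponents named in the statement, the conjugate exponent and the resulting order can be worked out from $q = p/(p-1)$ and $k = q - 2$ alone:

```latex
\begin{aligned}
p = 2            &\;\Rightarrow\; q = \frac{2}{2-1} = 2,     &\quad k &= q - 2 = 0,\\
p = \tfrac{3}{2} &\;\Rightarrow\; q = \frac{3/2}{1/2} = 3,   &\quad k &= q - 2 = 1,\\
p = \tfrac{4}{3} &\;\Rightarrow\; q = \frac{4/3}{1/3} = 4,   &\quad k &= q - 2 = 2.
\end{aligned}
```

These values of $k$ line up with the mode names c0, c1, and c2 used above.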

Figures (20)

  • Figure 1: Top: SoftJAX and SoftTorch provide differentiable surrogates for discrete operations, e.g., $\operatorname{rank}_{\tau}$ instead of $\operatorname{rank}$. Numerous soft approximations can be obtained by tuning the softness parameter $\tau$, selecting the softening method (e.g., "neuralsort" or "softsort"), and choosing a smoothness mode (e.g., smooth for $\mathcal{C}^\infty$ or c1 for $\mathcal{C}^1$). Bottom: The soft surrogates shown above are instantiated using just three lines of code.
  • Figure 2: The function relu_st uses the straight-through trick: the hard function jax.nn.relu is applied in the forward pass, while the soft function sj.relu is used for gradient computation.
  • Figure 3: Arrows depict the normalized gradient of the product of two functions $f(x,y)=\operatorname{relu}(x)$ and $g(x,y)=\operatorname{relu}(y)$. Applying straight-through estimation on each function individually causes the gradient $\nabla (f_{\text{STE}} \cdot g_{\text{STE}})$ to be zero if $x<0$, whereas the gradient $\nabla (f \cdot g)_{\text{STE}}$ is non-zero.
  • Figure 4: Top: The Heaviside function surrogates $H_{\tau}$ are used to derive the $\operatorname{sign}$, $\operatorname{round}$, $\operatorname{abs}$, and $\operatorname{clip}$ functions. Bottom: Comparison operations compare values $x,y\in\mathbb{R}$ by interpreting $H_{\tau}$ as a CDF defining the probability that $x>0$.
  • Figure 5: OT-based $\mathop{\mathrm{arg\,topk}}\nolimits$ computes a (scaled) optimal transport plan $P^{\star}_{\tau}$ whose entries $P_{ij}^\star$ contain the probability mass transported from the $j$-th entry of $\mathbf{x}$ to the $i$-th entry of the anchor $\mathbf{y}$.
  • ...and 15 more figures
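
The straight-through trick described in the captions of Figures 2 and 3 can be sketched in plain JAX as follows. This is a hedged illustration, not the libraries' implementation: `soft_relu` is a softplus stand-in for `sj.relu`, and the helpers `straight_through`, `product_of_ste`, and `ste_of_product` are hypothetical names introduced here. The example evaluates the gradients at a single point where both inputs are negative.

```python
import jax
import jax.numpy as jnp


def soft_relu(x, tau=1.0):
    # Stand-in soft surrogate (scaled softplus); an assumption, not sj.relu.
    return tau * jax.nn.softplus(x / tau)


def straight_through(hard, soft):
    # Forward value equals `hard`; the gradient is taken from `soft` only,
    # because the correction term is wrapped in stop_gradient.
    return soft + jax.lax.stop_gradient(hard - soft)


def relu_st(x):
    return straight_through(jax.nn.relu(x), soft_relu(x))


def product_of_ste(x, y):
    # Straight-through applied to each factor separately.
    return relu_st(x) * relu_st(y)


def ste_of_product(x, y):
    # Straight-through applied once to the whole product.
    hard = jax.nn.relu(x) * jax.nn.relu(y)
    soft = soft_relu(x) * soft_relu(y)
    return straight_through(hard, soft)


x = jnp.asarray(-1.0)
y = jnp.asarray(-1.0)

value = relu_st(x)                  # hard forward value
grad_single = jax.grad(relu_st)(x)  # gradient of the soft surrogate
grad_separate = jax.grad(product_of_ste, argnums=(0, 1))(x, y)
grad_joint = jax.grad(ste_of_product, argnums=(0, 1))(x, y)
```

At this point the per-factor variant multiplies each soft derivative by the other factor's hard forward value, which is zero, so its gradient vanishes. Wrapping the whole product keeps both soft factors in the backward pass and yields a non-zero gradient.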

Theorems & Definitions (2)

  • Theorem 1: Smoothness of $p$-norm regularized projections
  • proof