Learning Set Functions with Implicit Differentiation
Gözde Özcan, Chengzhi Shi, Stratis Ioannidis
TL;DR
The paper tackles learning set functions from data generated by an optimal-subset oracle by casting the problem as learning an energy-based model with a mean-field variational approximation. It derives a fixed-point equation for the mean-field parameters and, using Banach's fixed-point theorem, provides convergence guarantees when the multilinear relaxation is suitably bounded. To remove the computational bottleneck of backpropagating through these iterations, the authors use implicit differentiation (via the implicit function theorem) to compute gradients without unrolling the fixed-point iterations, yielding the iDiffMF method. Empirically, iDiffMF achieves competitive or superior Jaccard scores on diverse subset-selection tasks while reducing memory usage and keeping training efficient, which makes the approach practical for real-world set-function learning.
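The mean-field fixed-point equation can be illustrated with a small sketch. As an assumption for illustration only (not the paper's parameterization), the learned utility is replaced by a toy pairwise set function F(S) = Σ_{i∈S} a_i + Σ_{i<j; i,j∈S} W_ij, whose multilinear relaxation F̃(ψ) = aᵀψ + ½ψᵀWψ has gradient a + Wψ. Setting the derivative of the mean-field objective F̃(ψ) + Σ_i H(ψ_i) (H the binary entropy) to zero gives ψ = σ(∇F̃(ψ)). Because σ' ≤ 1/4, this update is a contraction in the ∞-norm whenever the largest absolute row sum of W (the Hessian of F̃) is below 4, so Banach's theorem yields a unique fixed point. This is one simple sufficient condition in the spirit of the paper's boundedness assumption, not necessarily the exact one it uses; the helper names are hypothetical.

```python
import math

# Toy stand-in (assumption, not the paper's model): F(S) = sum_{i in S} a_i + sum_{i<j in S} W_ij,
# with W symmetric and zero on the diagonal. Its multilinear relaxation is
# F~(psi) = a^T psi + 0.5 * psi^T W psi, whose gradient is a + W psi.


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def grad_multilinear(a, W, psi):
    """Gradient of the toy multilinear relaxation: a + W psi."""
    return [ai + sum(wij * pj for wij, pj in zip(row, psi)) for ai, row in zip(a, W)]


def contraction_constant(W):
    """Upper bound on the inf-norm Lipschitz constant of psi -> sigmoid(a + W psi).

    sigmoid' <= 1/4, so the bound is (max absolute row sum of W) / 4.
    """
    return max(sum(abs(w) for w in row) for row in W) / 4.0


def mean_field_fixed_point(a, W, tol=1e-10, max_iter=1000):
    """Iterate psi <- sigmoid(grad F~(psi)); returns (psi, iterations used)."""
    bound = contraction_constant(W)
    if bound >= 1.0:
        # Banach's theorem needs a contraction; this sufficient check fails, so refuse.
        raise ValueError(f"contraction bound {bound:.3f} >= 1; convergence not guaranteed")
    psi = [0.5] * len(a)  # start from the uniform distribution over subsets
    for k in range(1, max_iter + 1):
        new = [sigmoid(g) for g in grad_multilinear(a, W, psi)]
        step = max(abs(x - y) for x, y in zip(new, psi))
        psi = new
        if step < tol:
            return psi, k
    return psi, max_iter


# Usage: a 3-item ground set whose max absolute row sum is 2.5, giving bound 0.625 < 1.
a = [1.0, -0.5, 0.2]
W = [[0.0, 1.5, -1.0],
     [1.5, 0.0, 0.5],
     [-1.0, 0.5, 0.0]]
psi, iters = mean_field_fixed_point(a, W)
print(f"psi* = {[round(p, 4) for p in psi]} after {iters} iterations")
```

The row-sum check is deliberately conservative: iterations may still converge when it fails, but Banach's theorem then no longer guarantees it.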
Abstract
Ou et al. (2022) introduce the problem of learning set functions from data generated by a so-called optimal subset oracle. Their approach approximates the underlying utility function with an energy-based model, whose parameters are estimated via mean-field variational inference. Ou et al. (2022) show that this reduces to fixed-point iterations; however, as the number of iterations increases, automatic differentiation quickly becomes computationally prohibitive due to the size of the Jacobians that are stacked during backpropagation. We address this challenge with implicit differentiation and examine the convergence conditions for the fixed-point iterations. We empirically demonstrate the efficiency of our method on synthetic and real-world subset selection applications, including product recommendation, set anomaly detection, and compound selection tasks.
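Why implicit differentiation avoids stacking Jacobians can be made concrete with a small sketch. It assumes, for illustration only, a toy pairwise set function whose multilinear relaxation has gradient a + Wψ (W symmetric, zero diagonal, maximum absolute row sum below 4), treats the weights a as the learnable parameters, and uses as the loss the negative log-likelihood of a hypothetical target subset under the mean-field distribution. None of these choices, nor the helper names, come from the paper, and this is not the authors' iDiffMF code. At the fixed point ψ* = σ(a + Wψ*), the implicit function theorem gives dψ*/da = (I - DW)^{-1} D with D = diag(σ'(a + Wψ*)). The backward pass therefore solves the adjoint equation u = ∇L(ψ*) + WDu and returns Du, using only ψ* instead of a stack of per-iteration Jacobians.

```python
import math

# Assumptions (illustrative only): toy pairwise set function with relaxation gradient a + W psi,
# W symmetric with zero diagonal and max absolute row sum < 4; the weights a are the parameters.


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(W, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]


def forward(a, W, tol=1e-14, max_iter=10_000):
    """Mean-field fixed point psi* = sigmoid(a + W psi*). Keeps only the current iterate."""
    psi = [0.5] * len(a)
    for _ in range(max_iter):
        new = [sigmoid(ai + zi) for ai, zi in zip(a, matvec(W, psi))]
        if max(abs(x - y) for x, y in zip(new, psi)) < tol:
            return new
        psi = new
    return psi


def nll(psi, target):
    """Negative log-likelihood of the target subset under independent Bernoulli(psi_i) items."""
    return -sum(math.log(p) if i in target else math.log(1.0 - p) for i, p in enumerate(psi))


def implicit_grad(a, W, target, tol=1e-13, max_iter=10_000):
    """dL/da via the implicit function theorem, without unrolling the forward iterations."""
    psi = forward(a, W)
    z = [ai + zi for ai, zi in zip(a, matvec(W, psi))]
    d = [sigmoid(zi) * (1.0 - sigmoid(zi)) for zi in z]  # D = diag(sigmoid'(z))
    v = [-1.0 / p if i in target else 1.0 / (1.0 - p) for i, p in enumerate(psi)]  # dL/dpsi
    # Adjoint equation u = v + W D u, itself a contraction under the same row-sum condition.
    u = list(v)
    for _ in range(max_iter):
        new = [vi + wi for vi, wi in zip(v, matvec(W, [di * ui for di, ui in zip(d, u)]))]
        done = max(abs(x - y) for x, y in zip(new, u)) < tol
        u = new
        if done:
            break
    return [di * ui for di, ui in zip(d, u)]  # dL/da = D u


# Usage: compare against central finite differences through the forward solver.
a = [1.0, -0.5, 0.2]
W = [[0.0, 1.5, -1.0],
     [1.5, 0.0, 0.5],
     [-1.0, 0.5, 0.0]]
target = {0, 2}  # hypothetical oracle-selected subset
grad = implicit_grad(a, W, target)

eps = 1e-5
fd = []
for i in range(len(a)):
    ap, am = list(a), list(a)
    ap[i] += eps
    am[i] -= eps
    fd.append((nll(forward(ap, W), target) - nll(forward(am, W), target)) / (2 * eps))
print("implicit:", [round(g, 6) for g in grad])
print("finite differences:", [round(f, 6) for f in fd])
```

Because the forward solver keeps only its current iterate, memory does not grow with the number of iterations; the finite-difference comparison is only a sanity check on the adjoint formula.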
