Learning Set Functions with Implicit Differentiation

Gözde Özcan, Chengzhi Shi, Stratis Ioannidis

TL;DR

The paper addresses learning set functions from data generated by an optimal-subset oracle, casting the problem as learning an energy-based model fitted with a mean-field variational approximation. It derives a fixed-point equation for the mean-field parameters and shows, via Banach’s fixed-point theorem, that this equation has a unique solution and the iterations converge when the multilinear relaxation satisfies a boundedness assumption. To avoid the cost of backpropagating through unrolled fixed-point iterations, the authors apply implicit differentiation (via the implicit function theorem) to compute gradients directly at the fixed point, yielding the iDiffMF method. Empirically, iDiffMF achieves competitive or superior Jaccard scores on diverse subset-selection tasks while using less memory than automatic differentiation and keeping training efficient.
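
As a brief sketch of why the implicit function theorem removes the need to unroll, consider a generic fixed-point map. The symbol $T$ and the block notation below are assumptions introduced here for illustration, not the paper's own notation; the paper's map is the mean-field update applied to the parameters $\boldsymbol{\psi}$.

```latex
% Assumed generic notation: T is a differentiable fixed-point map,
% \psi^* its fixed point for parameters \theta.
\begin{align*}
\boldsymbol{\psi}^* &= T(\boldsymbol{\psi}^*, \boldsymbol{\theta}) \\
\frac{\partial \boldsymbol{\psi}^*}{\partial \boldsymbol{\theta}}
  &= \frac{\partial T}{\partial \boldsymbol{\psi}}\bigg|_{\boldsymbol{\psi}^*}
     \frac{\partial \boldsymbol{\psi}^*}{\partial \boldsymbol{\theta}}
   + \frac{\partial T}{\partial \boldsymbol{\theta}}\bigg|_{\boldsymbol{\psi}^*} \\
\frac{\partial \boldsymbol{\psi}^*}{\partial \boldsymbol{\theta}}
  &= \left( I - \frac{\partial T}{\partial \boldsymbol{\psi}}\bigg|_{\boldsymbol{\psi}^*} \right)^{-1}
     \frac{\partial T}{\partial \boldsymbol{\theta}}\bigg|_{\boldsymbol{\psi}^*}
\end{align*}
```

The last line depends only on Jacobians evaluated at the fixed point, so its cost does not grow with the number of iterations used to reach $\boldsymbol{\psi}^*$. The inverse exists whenever $T$ is a contraction in $\boldsymbol{\psi}$, which is the same kind of condition that Banach’s fixed-point theorem requires.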

Abstract

Ou et al. (2022) introduce the problem of learning set functions from data generated by a so-called optimal subset oracle. Their approach approximates the underlying utility function with an energy-based model, whose parameters are estimated via mean-field variational inference. Ou et al. (2022) show this reduces to fixed-point iterations; however, as the number of iterations increases, automatic differentiation quickly becomes computationally prohibitive due to the size of the Jacobians that are stacked during backpropagation. We address this challenge with implicit differentiation and examine the convergence conditions for the fixed-point iterations. We empirically demonstrate the efficiency of our method on synthetic and real-world subset selection applications, including product recommendation, set anomaly detection, and compound selection tasks.
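
To make the contrast between unrolling and implicit differentiation concrete, here is a minimal sketch on a toy scalar problem. It is not the paper's model: the map `T(psi, theta) = sigmoid(theta * psi + b)`, the bias `b`, and all function names are hypothetical stand-ins chosen so that the map is a contraction (the sigmoid's slope is at most 1/4, so `|theta| < 4` suffices). The gradient is computed from quantities at the fixed point only, without storing any intermediate iterates.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def T(psi, theta, b=-0.5):
    """Toy fixed-point map (hypothetical stand-in for a mean-field update)."""
    return sigmoid(theta * psi + b)


def solve_fixed_point(theta, b=-0.5, psi0=0.5, tol=1e-12, max_iter=1000):
    """Iterate psi <- T(psi, theta) until successive iterates agree to tol."""
    psi = psi0
    for _ in range(max_iter):
        psi_next = T(psi, theta, b)
        if abs(psi_next - psi) < tol:
            return psi_next
        psi = psi_next
    return psi


def implicit_grad(theta, b=-0.5):
    """d psi*/d theta via the implicit function theorem.

    Differentiating psi* = T(psi*, theta) gives
    d psi*/d theta = (dT/dtheta) / (1 - dT/dpsi), evaluated at psi*.
    """
    psi = solve_fixed_point(theta, b)
    s = T(psi, theta, b)
    ds = s * (1.0 - s)        # derivative of the sigmoid at the fixed point
    dT_dpsi = ds * theta      # partial derivative of T with respect to psi
    dT_dtheta = ds * psi      # partial derivative of T with respect to theta
    return dT_dtheta / (1.0 - dT_dpsi)


if __name__ == "__main__":
    theta = 2.0
    h = 1e-5
    # Central finite difference through the solver, used only as a check.
    fd = (solve_fixed_point(theta + h) - solve_fixed_point(theta - h)) / (2 * h)
    print("implicit gradient:", implicit_grad(theta))
    print("finite difference:", fd)
```

In the vector case the scalar division becomes a linear solve with the matrix `I - dT/dpsi`, which is where the memory saving over stacking one Jacobian per unrolled iteration comes from.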

Paper Structure

This paper contains 49 sections, 10 theorems, 45 equations, 1 figure, 7 tables, 2 algorithms.

Key Result

Theorem 4.2

Assume a set function $F_{\boldsymbol{\theta}}: 2^V \rightarrow \mathbb{R}$ satisfies Asm. asm:bound. Then, the fixed-point given in Eq. eq:fixed_point has a unique solution $\boldsymbol{\psi}^* \in [0, 1]^{|V|}$ where $\boldsymbol{\psi}^* = \boldsymbol{\sigma}(\nabla_{\boldsymbol{\psi}} \Tilde{F} (

Figures (1)

  • Figure 1: Effects of the choice of differentiation method on the relationship between the allocated GPU memory and the number of fixed-point iterations across different datasets. Blue lines represent automatic differentiation ($\mathtt{DiffMF}$), while the orange lines represent implicit differentiation ($\mathtt{iDiffMF}$). The markers denote the average memory usage. The area between the recorded minimum and maximum memory usage is shaded.

Theorems & Definitions (18)

  • Theorem 4.2
  • Theorem 4.3
  • Definition A.1
  • Definition A.2
  • Theorem A.3
  • proof
  • Lemma C.1
  • proof
  • Corollary C.2
  • Theorem D.1
  • ...and 8 more