
Binary Search with Distributional Predictions

Michael Dinitz, Sungjin Im, Thomas Lavastida, Benjamin Moseley, Aidin Niaparast, Sergei Vassilvitskii

TL;DR

This work initiates the study of algorithms with distributional predictions, where the prediction itself is a distribution, and gives a binary search algorithm with query complexity $O(H(p) + \log \eta)$. This also yields the first distributionally-robust algorithm for the classical problem of computing an optimal binary search tree given a distribution over target keys.

Abstract

Algorithms with (machine-learned) predictions is a powerful framework for combining traditional worst-case algorithms with modern machine learning. However, the vast majority of work in this space assumes that the prediction itself is non-probabilistic, even if it is generated by some stochastic process (such as a machine learning system). This is a poor fit for modern ML, particularly modern neural networks, which naturally generate a distribution. We initiate the study of algorithms with distributional predictions, where the prediction itself is a distribution. We focus on one of the simplest yet most fundamental settings: binary search (or searching a sorted array). This setting has one of the simplest algorithms with a point prediction, but what happens if the prediction is a distribution? We show that this is a richer setting: there are simple distributions where using the classical prediction-based algorithm with any single prediction does poorly. Motivated by this, as our main result, we give an algorithm with query complexity $O(H(p) + \log \eta)$, where $H(p)$ is the entropy of the true distribution $p$ and $\eta$ is the earth mover's distance between $p$ and the predicted distribution $\hat p$. This also yields the first distributionally-robust algorithm for the classical problem of computing an optimal binary search tree given a distribution over target keys. We complement this with a lower bound showing that this query complexity is essentially optimal (up to constants), and experiments validating the practical usefulness of our algorithm.
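The two quantities in the bound, $H(p)$ and $\eta$, can be made concrete with a small sketch. The Python below is illustrative only and is not the authors' algorithm. It assumes the keys are the indices 0..n-1 of a sorted array with unit spacing, so the earth mover's distance reduces to the sum of absolute differences between the two CDFs, and it measures entropy in bits. The helper `mass_bisection_search` (a hypothetical name) implements the classical strategy of probing the key that splits the remaining predicted mass in half; it is efficient when $\hat p$ is accurate but is not robust to a wrong prediction.

```python
import math
from typing import List, Sequence, Tuple


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy H(p) in bits; zero-probability keys contribute 0."""
    return -sum(x * math.log2(x) for x in p if x > 0)


def emd_1d(p: Sequence[float], q: Sequence[float]) -> float:
    """Earth mover's distance between two distributions over keys 0..n-1.

    Assumes adjacent keys are at distance 1, so the EMD equals the sum of
    absolute differences between the two cumulative distributions.
    """
    if len(p) != len(q):
        raise ValueError("distributions must have the same support size")
    dist, cp, cq = 0.0, 0.0, 0.0
    for a, b in zip(p, q):
        cp += a
        cq += b
        dist += abs(cp - cq)
    return dist


def mass_bisection_search(
    keys: List[int], target: int, p_hat: Sequence[float]
) -> Tuple[int, int]:
    """Search sorted `keys` for `target`, probing the key that splits the
    predicted mass of the current window in half.

    Returns (index of target or -1, number of probes). Each three-way
    comparison against keys[mid] counts as one query.
    """
    lo, hi = 0, len(keys) - 1
    queries = 0
    while lo <= hi:
        total = sum(p_hat[lo:hi + 1])
        if total > 0:
            # Smallest index whose prefix mass (from lo) reaches half the window's mass.
            acc, mid = 0.0, hi
            for i in range(lo, hi + 1):
                acc += p_hat[i]
                if acc >= total / 2:
                    mid = i
                    break
        else:
            mid = (lo + hi) // 2  # no predicted mass left: fall back to plain binary search
        queries += 1
        if keys[mid] == target:
            return mid, queries
        if keys[mid] < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return -1, queries
```

For example, with geometrically decaying predicted masses $\hat p_i = 2^{-(i+1)}$ on 20 keys and the target at the last key, the mass-bisection rule probes all 20 keys, whereas plain binary search needs at most 5 probes.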

Paper Structure

This paper contains 25 sections, 4 theorems, 8 equations, 4 figures.

Key Result

Theorem 1

The expected query complexity of the described algorithm is at most $4H(p)+8\max(\log(\eta)+2,1)+8 = O(H(p)+\max(\log(\eta),0))$. The maximum with 0 accounts for the case $\eta \in [0,1)$, where $\log(\eta) < 0$.
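To see how the two terms of Theorem 1 trade off, the explicit bound can be evaluated directly. This is a minimal sketch that assumes base-2 logarithms (the base is not stated in this summary) and treats $\eta = 0$ as $\log(\eta) = -\infty$, so the clamped term becomes 1; the function name `theorem1_bound` is hypothetical.

```python
import math


def theorem1_bound(entropy_bits: float, eta: float) -> float:
    """Evaluate 4*H(p) + 8*max(log(eta) + 2, 1) + 8 from Theorem 1.

    Assumes base-2 logarithms. `eta` is the earth mover's distance between
    p and p_hat; eta == 0 is treated as log(eta) = -infinity.
    """
    if eta < 0:
        raise ValueError("eta must be non-negative")
    log_eta = math.log2(eta) if eta > 0 else -math.inf
    return 4 * entropy_bits + 8 * max(log_eta + 2, 1) + 8


# A perfect prediction (eta = 0) and a slightly wrong one (eta = 0.5) give the
# same bound, because max(log(eta) + 2, 1) is clamped at 1 whenever eta <= 0.5.
print(theorem1_bound(3.0, 0.0))   # 28.0
print(theorem1_bound(3.0, 0.5))   # 28.0
print(theorem1_bound(3.0, 16.0))  # 68.0
```

Under this reading, the penalty for a poor prediction grows only logarithmically in $\eta$: doubling $\eta$ adds at most 8 to the bound.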

Figures (4)

  • Figure 1: Results for synthetic data experiments. The y-axis measures the average cost (query complexity) of each algorithm and the x-axis measures the amount of shift in the test distribution. The training and test data are regenerated 5 times; the solid lines show the mean over these runs and the shaded regions around them show the standard deviation.
  • Figure 2: The train and test distributions when $t=50$ for the three datasets.
  • Figure 3: Results for real data experiments. The y-axis measures the average cost of each algorithm and the x-axis indicates the fraction of the dataset used for training.
  • Figure 4: Results for real data experiments. The y-axis measures the average cost of each algorithm and the x-axis indicates the logarithm of the earth mover's distance between $\hat{p}$ and $p$.

Theorems & Definitions (7)

  • Theorem 1
  • Proof of Theorem 1
  • Theorem 2
  • Proof
  • Corollary 3
  • Theorem 4
  • Proof of Theorem 4