Determinantal Point Processes for Mini-Batch Diversification

Cheng Zhang, Hedvig Kjellstrom, Stephan Mandt

TL;DR

Determinantal Point Processes for Mini-Batch Diversification introduces DM-SGD, a biased mini-batch sampling method that uses $k$-DPPs to produce diverse mini-batches based on a data similarity kernel. The method generalizes biased stratified sampling, enables variance reduction of stochastic gradients, and can be precomputed. Theoretical analysis shows DM-SGD corresponds to a reweighted empirical risk with DPP marginals and provides conditions under which gradient variance is reduced. Empirically, DM-SGD improves topic diversity in LDA and yields higher classification accuracy for supervised tasks, including CNNs on MNIST and multiclass logistic regression, especially under data imbalance. This approach offers a general, objective-agnostic way to balance data subsets and accelerate learning in settings with expensive gradient updates.
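The reweighting claim follows from linearity of expectation alone. Below is a short derivation, assuming the mini-batch gradient is the plain average over the $k$ points drawn by the $k$-DPP. Here $\pi_i = P(i \in B)$ denotes the marginal inclusion probability of point $i$; this notation is introduced for illustration, and the paper's exact normalization may differ.

```latex
\mathbb{E}_{B \sim k\text{-DPP}(L)}\Big[\tfrac{1}{k}\sum_{i \in B} \nabla \ell(x_i,\theta)\Big]
  = \tfrac{1}{k}\sum_{i=1}^{N} \pi_i \, \nabla \ell(x_i,\theta)
  = \nabla\Big(\tfrac{1}{k}\sum_{i=1}^{N} \pi_i \, \ell(x_i,\theta)\Big),
\qquad \sum_{i=1}^{N} \pi_i = \mathbb{E}\,|B| = k .
```

Under this assumption, the diversified mini-batch gradient is unbiased for a $\pi$-weighted empirical risk rather than for the uniform average. With uniform sampling, $\pi_i = k/N$, which recovers the usual objective $\frac{1}{N}\sum_i \ell(x_i,\theta)$.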

Abstract

We study a mini-batch diversification scheme for stochastic gradient descent (SGD). While classical SGD relies on uniformly sampling data points to form a mini-batch, we propose a non-uniform sampling scheme based on the Determinantal Point Process (DPP). The DPP relies on a similarity measure between data points and gives low probabilities to mini-batches that contain redundant data and higher probabilities to mini-batches with more diverse data. This simultaneously balances the data and leads to stochastic gradients with lower variance. We term this approach Diversified Mini-Batch SGD (DM-SGD). We show that regular SGD and a biased version of stratified sampling emerge as special cases. Furthermore, DM-SGD generalizes stratified sampling to cases where no discrete features exist to bin the data into groups. We show experimentally that our method results in more interpretable and diverse features in unsupervised setups, and in better classification accuracies in supervised setups.
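A $k$-DPP assigns each size-$k$ mini-batch $S$ a probability proportional to the principal minor $\det(L_S)$ of the similarity matrix $L$. Subsets of mutually similar points therefore get small determinants. The sketch below computes this distribution by brute-force enumeration, which is only feasible for a handful of points. It illustrates the sampling rule and is not presented as the paper's sampler; practical $k$-DPP samplers avoid enumeration, e.g. by working with an eigendecomposition of $L$. The function names and the toy matrix are hypothetical.

```python
import itertools

import numpy as np


def kdpp_subset_probs(L, k):
    """Exact k-DPP distribution by enumeration: P(S) proportional to det(L_S), |S| = k.

    Brute force over all C(N, k) subsets, so only usable for tiny N (illustration only).
    """
    n = L.shape[0]
    subsets = list(itertools.combinations(range(n), k))
    # Principal minor det(L_S) of each candidate mini-batch S.
    weights = np.array([np.linalg.det(L[np.ix_(s, s)]) for s in subsets])
    weights = np.clip(weights, 0.0, None)  # guard against tiny negative round-off
    return subsets, weights / weights.sum()


def sample_minibatch(L, k, rng):
    """Draw one diversified mini-batch (a tuple of k indices) from the k-DPP."""
    subsets, probs = kdpp_subset_probs(L, k)
    return subsets[rng.choice(len(subsets), p=probs)]


# Hypothetical similarity matrix: points 0 and 1 are near-duplicates, point 2 is distinct.
L = np.array([[1.0, 0.9, 0.1],
              [0.9, 1.0, 0.1],
              [0.1, 0.1, 1.0]])
subsets, probs = kdpp_subset_probs(L, k=2)
for s, p in zip(subsets, probs):
    print(s, round(p, 3))

rng = np.random.default_rng(0)
print(sample_minibatch(L, k=2, rng=rng))
```

The redundant pair (0, 1) has $\det(L_S) = 1 - 0.9^2 = 0.19$, whereas each pair containing point 2 has $\det(L_S) = 0.99$. Mini-batches that include the distinct point are therefore strongly preferred.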

Paper Structure

This paper contains 24 sections, 5 theorems, 22 equations, 11 figures, 5 tables, 3 algorithms.

Key Result

Proposition 1

Biased stratified sampling (StS) zhao2014accelerating, where data from different strata are subsampled with equal probability, is equivalent to DM-SGD with a similarity matrix $L$. Here $L$ is a block-diagonal matrix with $L_{ij} = 1$ if $H_i = H_j$ and $L_{ij} = 0$ otherwise, where $H_i$ denotes the label of the stratum of data point $i$.
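The following numerical check of Proposition 1 assumes two things: the block-diagonal matrix $L_{ij} = 1$ for same-stratum pairs and $0$ otherwise, and $k$ equal to the number of strata. Any mini-batch with two points from one stratum contains two identical rows of $L_S$, so $\det(L_S) = 0$. A batch with one point per stratum has $L_S = I$ and $\det(L_S) = 1$. The $k$-DPP is therefore uniform over batches with exactly one point per stratum, so every stratum carries the same total probability regardless of its size. The stratum labels below are hypothetical.

```python
import itertools

import numpy as np

# Hypothetical stratum labels H_i: stratum 0 is heavily over-represented.
H = np.array([0, 0, 0, 0, 1, 2])
# Block-diagonal similarity: L_ij = 1 if H_i == H_j, else 0.
L = (H[:, None] == H[None, :]).astype(float)
k = len(np.unique(H))  # assumption: one slot per stratum

n = len(H)
marginals = np.zeros(n)
total = 0.0
for S in itertools.combinations(range(n), k):
    # det(L_S) is 1 if S hits k distinct strata and 0 otherwise (up to round-off).
    w = round(np.linalg.det(L[np.ix_(S, S)]))
    total += w
    marginals[list(S)] += w
marginals /= total  # P(i in batch) for each data point i

# Total inclusion probability per stratum.
for h in np.unique(H):
    print(h, marginals[H == h].sum())
```

Each of the four points in the large stratum is included with probability $1/4$, while the two singleton points are always included. Every stratum thus contributes exactly one point per mini-batch, matching biased stratified sampling.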

Figures (11)

  • Figure 1: Sampling mini-batches using the $k$-DPP. For an imbalanced dataset, our method results in diversified mini-batches.
  • Figure 2: Visualization of different non-uniform data subsampling schemes on toy data. Panel (a) shows a homogeneous distribution of data. We assume that we only observe an imbalanced subset, shown in panel (b). Panels (c), (d), and (e) demonstrate different biased sampling methods that aim to restore balance in the data; the thickness of each data point indicates its sampling frequency. Biased stratified sampling (c) divides the feature space vertically along certain dimensions, whereas pre-clustering (d) defines the strata as clusters obtained from k-means fu2017CPSGMCMC (we used $k=4$). The black diamonds show the cluster centers, and data are colored by cluster membership. Panel (e) shows the result of the $k$-DPP, using an RBF kernel of spatial distances as the similarity measure between data points. In this example, the $k$-DPP best restores the balance of the original data set.
  • Figure 3: Per-topic word distributions for the synthetic data. Each row represents a topic and each column a word. (a) shows the ground truth used to generate the synthetic data with LDA. (b) shows the estimate of this latent variable obtained by LDA with traditional stochastic variational inference (SVI). (c) shows the estimate obtained with DM-SVI.
  • Figure 4: The frequency of class labels of the training dataset (in blue) and of the balanced dataset (in yellow). While explicit class label information is withheld, the algorithm partially balances class contributions.
  • Figure 5: Confusion matrix for text classification based on LDA features obtained from SVI (a) and the proposed DM-SVI (b). DM-SVI features lead to better accuracies.
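Figure 2 uses an RBF kernel of spatial distances as the similarity measure. The sketch below builds such a kernel, assuming $L_{ij} = \exp(-\lVert x_i - x_j \rVert^2 / (2\sigma^2))$ with a hand-picked bandwidth $\sigma$. The bandwidth, the points, and the function name are hypothetical. For a pair $S = \{i, j\}$, $\det(L_S) = 1 - L_{ij}^2$, so a $k$-DPP with $k = 2$ strongly prefers pairs that are far apart over pairs from the same dense region.

```python
import numpy as np


def rbf_kernel(X, sigma=1.0):
    """RBF similarity L_ij = exp(-||x_i - x_j||^2 / (2 sigma^2)); the diagonal is 1."""
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma**2))


# Two nearby points in a dense region and one isolated point (hypothetical toy data).
X = np.array([[0.0, 0.0],
              [0.1, 0.0],
              [3.0, 3.0]])
L = rbf_kernel(X, sigma=1.0)

# Unnormalized k-DPP weight (k = 2) of each pair: det(L_S) = 1 - L_ij^2.
for i, j in [(0, 1), (0, 2), (1, 2)]:
    w = np.linalg.det(L[np.ix_([i, j], [i, j])])
    print((i, j), round(w, 4))
```

The bandwidth controls how quickly similarity decays with distance. A very small $\sigma$ makes almost all points look dissimilar, and the $k$-DPP then approaches uniform sampling.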

Theorems & Definitions (8)

  • Proposition 1
  • proof
  • Proposition 2
  • Proposition 3
  • proof
  • Proposition 4
  • Theorem 1
  • proof