
FIRAL: An Active Learning Algorithm for Multinomial Logistic Regression

Youguang Chen, George Biros

TL;DR

FIRAL addresses pool-based active learning for multiclass classification with multinomial logistic regression by tying excess-risk control to the Fisher Information Ratio (FIR) between the unlabeled distribution and a chosen sampling distribution. It introduces a two-stage FIR-focused method: first a convex relaxation to minimize FIR, then a regret-minimization (FTRL) rounding to select the actual labeled points, with provable $(1+\epsilon)$-approximation guarantees. Theoretical results provide finite-sample upper and lower bounds on the excess risk in terms of FIR under sub-Gaussian assumptions, complemented by bounded-domain analyses. Empirical results on MNIST, CIFAR-10, and ImageNet-50 show that FIRAL consistently outperforms several baselines, particularly in low-sample regimes, underscoring its practical impact for efficient multiclass active learning.
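
The FIR compares the Fisher information of the unlabeled pool distribution $p$ with that of a sampling distribution $q$. Below is a minimal NumPy sketch of that quantity, assuming the common reference-class parameterization (the logit of class $c$ is fixed at zero, so the parameter has $d(c-1)$ entries) and the per-point Hessian $(\mathrm{diag}(h) - h h^\top) \otimes x x^\top$, where $h$ holds the first $c-1$ class probabilities. The function names are hypothetical and the paper's exact conventions may differ. When $q = p$, the ratio reduces to $\mathrm{trace}(I) = d(c-1)$, which serves as a sanity check.

```python
import numpy as np

def softmax_probs(X, Theta):
    """Probabilities of the first c-1 classes for multinomial logistic regression.

    Assumption: Theta has shape (d, c-1) and class c is the reference class
    whose logit is fixed at 0.
    """
    logits = np.hstack([X @ Theta, np.zeros((X.shape[0], 1))])  # (n, c)
    logits -= logits.max(axis=1, keepdims=True)                 # numerical stability
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)
    return P[:, :-1]                                            # drop reference class

def point_hessians(X, Theta):
    """Per-point Hessian of the log-loss: (diag(h) - h h^T) kron (x x^T).

    The Hessian does not depend on the label y, so it can be computed on the
    unlabeled pool. Returns an array of shape (n, d(c-1), d(c-1)).
    """
    return np.stack([
        np.kron(np.diag(h) - np.outer(h, h), np.outer(x, x))
        for x, h in zip(X, softmax_probs(X, Theta))
    ])

def fisher_information_ratio(H_points, weights):
    """FIR = trace(H_q^{-1} H_p).

    H_p: uniform average of the per-point Hessians (pool distribution p).
    H_q: weighted average under the sampling weights (distribution q).
    """
    H_p = H_points.mean(axis=0)
    H_q = np.tensordot(weights, H_points, axes=1)
    return np.trace(np.linalg.solve(H_q, H_p))

# Usage: d = 3 features, c = 4 classes, 200 pool points.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 3))
Theta = rng.standard_normal((3, 3))
H = point_hessians(X, Theta)
uniform = np.full(len(X), 1 / len(X))
print(fisher_information_ratio(H, uniform))  # approximately 9.0 = d * (c - 1)
```

Solving $H_q Z = H_p$ instead of forming $H_q^{-1}$ explicitly is the usual numerically safer choice.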

Abstract

We investigate theory and algorithms for pool-based active learning for multiclass classification using multinomial logistic regression. Using finite-sample analysis, we prove that the Fisher Information Ratio (FIR) provides both lower and upper bounds on the excess risk. Based on our theoretical analysis, we propose an active learning algorithm that employs regret minimization to minimize the FIR. To verify our derived excess risk bounds, we conduct experiments on synthetic datasets. Furthermore, we compare FIRAL with five other methods and find that our scheme outperforms them: it consistently produces the smallest classification error in the multiclass logistic regression setting, as demonstrated through experiments on MNIST, CIFAR-10, and 50-class ImageNet.
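
To make "minimize the FIR" concrete, here is a hedged sketch of the relaxed problem $\min_q \mathrm{trace}(H_q^{-1} H_p)$ over probability distributions $q$ on the pool (the paper's exact constraint set may differ). It uses a swapped-in solver, exponentiated gradient (entropic mirror descent), followed by naive top-$b$ rounding. This is not FIRAL's FTRL-based rounding and carries none of its guarantees. The gradient uses $\partial_{q_i}\,\mathrm{trace}(A^{-1} H_p) = -\mathrm{trace}(H_i A^{-1} H_p A^{-1})$ with $A = \sum_i q_i H_i$. All function names, the reference-class parameterization, and the step-size heuristic are illustrative assumptions.

```python
import numpy as np

def logistic_point_hessians(X, Theta):
    """Per-point Hessians (diag(h) - h h^T) kron x x^T; class c is the
    zero-logit reference class (assumed parameterization, Theta is d x (c-1))."""
    logits = np.hstack([X @ Theta, np.zeros((len(X), 1))])
    P = np.exp(logits - logits.max(axis=1, keepdims=True))
    P = P / P.sum(axis=1, keepdims=True)
    return np.stack([np.kron(np.diag(h) - np.outer(h, h), np.outer(x, x))
                     for x, h in zip(X, P[:, :-1])])

def fir(H, q):
    """trace(H_q^{-1} H_p), with H_p the uniform pool average."""
    return np.trace(np.linalg.solve(np.tensordot(q, H, axes=1), H.mean(axis=0)))

def relaxed_fir_weights(H, n_iters=50, eta=0.1):
    """Exponentiated-gradient descent on the simplex for min_q FIR(q).
    A stand-in solver for illustration; returns the best iterate seen."""
    n = len(H)
    H_p = H.mean(axis=0)
    q = np.full(n, 1.0 / n)
    best_q, best_val = q.copy(), fir(H, q)
    for _ in range(n_iters):
        A_inv = np.linalg.inv(np.tensordot(q, H, axes=1))
        M = A_inv @ H_p @ A_inv
        # Gradient entries: -trace(H_i M), computed for all i at once.
        g = -np.einsum("ijk,jk->i", H, M)
        # Multiplicative update; dividing by max|g| is a heuristic step scaling.
        q = q * np.exp(-eta * g / np.abs(g).max())
        q /= q.sum()
        val = fir(H, q)
        if val < best_val:
            best_q, best_val = q.copy(), val
    return best_q, best_val

def top_b(q, b):
    """Naive rounding: indices of the b largest relaxed weights."""
    return np.argsort(q)[::-1][:b]

# Usage: d = 3 features, c = 4 classes, select b = 10 of 200 pool points.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 3))
H = logistic_point_hessians(X, rng.standard_normal((3, 3)))
q, val = relaxed_fir_weights(H)
print(fir(H, np.full(len(X), 1 / len(X))), val)  # uniform FIR (about 9.0), then a value no larger
print(top_b(q, 10))
```

Tracking the best iterate guarantees the returned FIR never exceeds the uniform-sampling FIR, even if an individual step overshoots.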

Paper Structure

This paper contains 61 sections, 37 theorems, 253 equations, 16 figures, and 3 algorithms.

Key Result

Lemma 2

If the sub-Gaussian assumption holds for $q(x)$, then for $(x,y) \sim \pi_q(x,y)$:

Figures (16)

  • Figure 1: FIR prefactors in the excess-risk bound for the sub-Gaussian case.
  • Figure 2: Synthetic experiments: excess risk of $p(x)$ as a function of the FIR ($\mathbf{H}_q^{-1} \cdot \mathbf{H}_p$) in dilation and translation tests.
  • Figure 3: Active learning results for MNIST (left), CIFAR-10 (center), and ImageNet-50 (right). Black dashed lines in the upper-row plots show the classification accuracy obtained using all points in $U$ with their labels. The lower row shows 50 images selected in the first round of the active learning process for the ImageNet-50 dataset.
  • Figure 4: Plots of the first two coordinates of points drawn from the joint distribution $\pi_p(x,y)$.
  • Figure 5: Excess risk of $q(x)$ as a function of $n$, $d$, and $c-1$. The dashed black line in the left plot indicates an inversely proportional relation. The dashed black lines in the center and right plots indicate linear relations.
  • ...and 11 more figures

Theorems & Definitions (69)

  • Lemma 2
  • Theorem 3
  • Theorem 4
  • Lemma 5
  • Proposition 6
  • Proposition 7
  • Proposition 8
  • Proposition 9
  • Theorem 10
  • Definition 11: Sub-Gaussian random variable
  • ...and 59 more