
The Implicit Bias of Gradient Descent on Separable Multiclass Data

Hrithik Ravi, Clayton Scott, Daniel Soudry, Yutong Wang

TL;DR

This work uses the framework of Permutation Equivariant and Relative Margin-based (PERM) losses to introduce a multiclass extension of the exponential tail property, and extends the implicit bias result of Soudry et al. [2018] to multiclass classification.

Abstract

Implicit bias describes the phenomenon where optimization-based training algorithms, without explicit regularization, show a preference for simple estimators even when more complex estimators have equal objective values. Multiple works have developed the theory of implicit bias for binary classification under the assumption that the loss satisfies an exponential tail property. However, there is a noticeable gap in the analysis of multiclass classification: only a handful of results exist, and they are restricted to the cross-entropy loss. In this work, we employ the framework of Permutation Equivariant and Relative Margin-based (PERM) losses [Wang and Scott, 2024] to introduce a multiclass extension of the exponential tail property. This class of losses includes cross-entropy as well as other losses. Using this framework, we extend the implicit bias result of Soudry et al. [2018] to multiclass classification. Furthermore, our proof techniques closely mirror those of the binary case, illustrating the power of the PERM framework for bridging the binary-multiclass gap.
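
To make the phenomenon concrete, here is a minimal sketch in pure Python. The toy dataset, the step size of 1.0, and all function names are our own hypothetical choices, and the loss is plain cross-entropy rather than a general PERM loss. On linearly separable data, the training loss can only approach zero if $\|W\|$ grows without bound, so the object of interest is the direction $W(t)/\|W(t)\|$. The implicit bias result says that this direction approaches the hard-margin multiclass SVM solution. Assuming the standard formulation, that solution minimizes $\|W\|_F^2$ subject to $(w_{y_i} - w_j)^\top x_i \ge 1$ for every sample $i$ and every class $j \ne y_i$. The sketch does not solve the SVM. Instead, it tracks the normalized margin $\min_{i,\, j \ne y_i} (w_{y_i} - w_j)^\top x_i / \|W\|_F$, which is the quantity that the SVM maximizes.

```python
import math

# Hypothetical toy data: K = 3 classes in d = 2, linearly separable
# (each class sits roughly along its own direction in the plane).
X = [(1.0, 0.1), (0.9, -0.2), (-0.5, 1.0), (-0.4, 0.8), (-0.5, -1.0), (-0.3, -0.9)]
Y = [0, 0, 1, 1, 2, 2]
K, D = 3, 2

def scores(W, x):
    """Class scores W x for a K x D weight matrix W stored as nested lists."""
    return [sum(W[k][m] * x[m] for m in range(D)) for k in range(K)]

def ce_loss_and_grad(W):
    """Average cross-entropy loss and its gradient with respect to W."""
    loss = 0.0
    grad = [[0.0] * D for _ in range(K)]
    for x, y in zip(X, Y):
        s = scores(W, x)
        smax = max(s)  # subtract the max for a numerically stable softmax
        exps = [math.exp(v - smax) for v in s]
        Z = sum(exps)
        p = [e / Z for e in exps]
        loss += -math.log(p[y])
        for k in range(K):
            g = p[k] - (1.0 if k == y else 0.0)
            for m in range(D):
                grad[k][m] += g * x[m]
    n = len(X)
    return loss / n, [[g / n for g in row] for row in grad]

def frobenius(W):
    return math.sqrt(sum(v * v for row in W for v in row))

def normalized_margin(W):
    """min over samples i and classes j != y_i of (w_{y_i} - w_j) . x_i, divided by ||W||_F."""
    m = min(
        scores(W, x)[y] - scores(W, x)[j]
        for x, y in zip(X, Y) for j in range(K) if j != y
    )
    return m / frobenius(W)

def gradient_descent(steps, lr=1.0):
    """Plain GD from W = 0; records (t, loss before step t, ||W||_F after, margin after)."""
    W = [[0.0] * D for _ in range(K)]
    history = []
    for t in range(1, steps + 1):
        loss, grad = ce_loss_and_grad(W)
        W = [[W[k][m] - lr * grad[k][m] for m in range(D)] for k in range(K)]
        if t in (10, 100, 1000, 10000):
            history.append((t, loss, frobenius(W), normalized_margin(W)))
    return history

for t, loss, norm, margin in gradient_descent(10000):
    print(f"t={t:>5}  loss={loss:.2e}  ||W||_F={norm:.2f}  margin={margin:.4f}")
```

In this run, one would expect the loss to shrink toward zero while $\|W\|_F$ keeps growing (roughly logarithmically, as in the binary case). Meanwhile, the normalized margin should creep toward its maximum. The slow approach is consistent with the slow directional convergence seen in the simulations.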

Paper Structure

This paper contains 47 sections, 19 theorems, 152 equations, 3 figures.

Key Result

Theorem 2.1

Let $\mathcal{L} : \mathbb{R}^K \rightarrow \mathbb{R}^K$ be a PERM loss with template $\psi$, and let $v \in \mathbb{R}^K$ and $y \in [K]$ be arbitrary. Then $\psi$ is a symmetric function. Moreover, … Conversely, let $\psi : \mathbb{R}^{K-1} \rightarrow \mathbb{R}$ be a symmetric function. Define a multiclass loss function $\mathcal{L} = (\mathcal{L}_1, \dots, \mathcal{L}_K) : \mathbb{R}^K \rightarrow \mathbb{R}^K$ …
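
The template idea can be checked numerically on cross-entropy. The following is a minimal sketch that assumes the relative-margin convention $u_j = v_y - v_j$ for $j \ne y$, listed in increasing $j$. Definition 2.1 may build the relative margins with a specific matrix and ordering, which we do not reproduce here, but symmetry of $\psi$ makes the ordering irrelevant. With the template $\psi(\mathbf{u}) = \log(1 + \sum_j \exp(-u_j))$ from Figure 1, written for general $K$, $\psi$ applied to the relative margins should equal $-\log \mathrm{softmax}(v)_y$. The helper names `ce_template`, `relative_margins`, and `perm_loss` are hypothetical.

```python
import itertools
import math

def ce_template(u):
    """psi(u) = log(1 + sum_j exp(-u_j)): the cross-entropy template, for any K - 1 inputs."""
    return math.log1p(sum(math.exp(-uj) for uj in u))

def relative_margins(v, y):
    """Relative margins v_y - v_j for every class j != y (order is irrelevant when psi is symmetric)."""
    return [v[y] - v[j] for j in range(len(v)) if j != y]

def perm_loss(psi, v, y):
    """Evaluate L_y(v) through a template: psi applied to the relative margins."""
    return psi(relative_margins(v, y))

def cross_entropy(v, y):
    """Direct definition -log softmax(v)_y, computed stably by shifting by max(v)."""
    m = max(v)
    return -(v[y] - m) + math.log(sum(math.exp(vk - m) for vk in v))

v = [2.0, -1.0, 0.5, 0.0]  # K = 4 class scores
for y in range(len(v)):
    print(y, perm_loss(ce_template, v, y), cross_entropy(v, y))

# Symmetry: permuting the relative margins leaves psi unchanged (up to rounding).
vals = [ce_template(list(p)) for p in itertools.permutations(relative_margins(v, 0))]
print("spread over permutations:", max(vals) - min(vals))
```

The two columns printed for each label should agree to floating-point precision. The spread over permutations should be at the level of rounding error, which illustrates the symmetry of $\psi$ asserted in Theorem 2.1.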

Figures (3)

  • Figure 1: An illustration of the exponential tail property for the cross-entropy (multinomial logistic) loss when $K=3$. Panel a: plot of $\psi(\mathbf{u})=\log(1+ \exp(-u_1) + \exp(-u_2))$, the template for the multinomial logistic loss; the complement of the positive orthant in the domain $\mathbb{R}^2$ is shown in gray. Panels b and c: the upper bound (black) and the lower bound (red), respectively, of $-\frac{\partial \psi}{\partial u_1}$ (blue). These bounds are derived in the appendix, with $u_\pm = 0$ and $c=1$. The lower bound is valid in the positive orthant, i.e., the red surface lies below the blue one there.
  • Figure 2: Small simulation with $N=10$, $d = 2$ and $K=3$. The loss used is the "PairLogLoss". Top row. Decision regions of classifiers along the gradient path $\mathbf{w}(t)$ at $t = 100,\,1000$, and $100000$, respectively from left to right. Bottom row. Decision regions of the hard-margin multiclass SVM. Note that most of the progress is made between iterations 100 and 1000.
  • Figure 3: Large simulations with $N=100$, $d = 10$ and $K=3$. The loss used is the "PairLogLoss". The curves are 10 independent runs with randomly sampled data and random initialization for gradient descent over $100000$ iterations. Note that the convergence in direction of the gradient descent iterates to the hard-margin SVM slows down in log-log space.
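
The kind of sandwich shown in Figure 1 can be spot-checked numerically. For the $K=3$ template, $-\frac{\partial \psi}{\partial u_1} = \frac{e^{-u_1}}{1 + e^{-u_1} + e^{-u_2}}$. This minimal sketch uses two elementary bounds of our own choosing: $e^{-u_1}(1 - e^{-u_1} - e^{-u_2}) \le -\frac{\partial \psi}{\partial u_1} \le e^{-u_1}$. Both follow in one line from $1 \le 1 + x$ and $\frac{1}{1+x} \ge 1 - x$ for $x \ge 0$. They need not coincide with the bounds derived in the appendix. The grid and function names are hypothetical.

```python
import math

def neg_partial_u1(u1, u2):
    """-d psi / d u1 for psi(u) = log(1 + exp(-u1) + exp(-u2)), the K = 3 template of Figure 1."""
    a, b = math.exp(-u1), math.exp(-u2)
    return a / (1.0 + a + b)

def upper_bound(u1, u2):
    # The denominator 1 + e^{-u1} + e^{-u2} is at least 1, so the derivative is at most e^{-u1}.
    return math.exp(-u1)

def lower_bound(u1, u2):
    # 1/(1+x) >= 1 - x for x >= 0, applied with x = e^{-u1} + e^{-u2}.
    return math.exp(-u1) * (1.0 - math.exp(-u1) - math.exp(-u2))

grid = [i * 0.5 for i in range(-6, 13)]  # u1, u2 range over [-3, 6] in steps of 0.5
violations = 0
for u1 in grid:
    for u2 in grid:
        g = neg_partial_u1(u1, u2)
        if not (lower_bound(u1, u2) <= g <= upper_bound(u1, u2)):
            violations += 1
print("violations:", violations)

# Deep in the positive orthant both ratios approach 1: the tail behaves like e^{-u1}.
for u in (2.0, 5.0, 10.0):
    g = neg_partial_u1(u, u)
    print(u, lower_bound(u, u) / g, upper_bound(u, u) / g)
```

The lower bound becomes vacuous (negative) when $u$ is small or negative. In the positive orthant, by contrast, both ratios tend to 1 as $u$ grows, which is the exponential-tail behaviour that the figure illustrates.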

Theorems & Definitions (37)

  • Definition 2.1: PERM loss [Wang and Scott, 2024]
  • Theorem 2.1 [Wang and Scott, 2024]
  • Definition 2.2: Multiclass exponential tail property
  • Remark 2.2
  • Theorem 3.4
  • Lemma 4.2
  • Lemma 4.3
  • Lemma 4.4
  • Proof
  • Remark 4.5
  • ...and 27 more