Does Flatness imply Generalization for Logistic Loss in Univariate Two-Layer ReLU Network?

Dan Qiao, Yu-Xiang Wang

TL;DR

The paper analyzes how flatness (minima stability) influences generalization for univariate, two-layer ReLU networks under logistic loss. It shows that flatness by itself does not guarantee generalization, constructing arbitrarily flat interpolants that overfit. However, with weight decay, flat solutions enjoy bounded generalization gaps, and under a weak generalization assumption, flatness can drive near-optimal excess risk within the convex hull of the ground-truth's uncertain regions via weighted TV$^{(1)}$ bounds. The authors derive both upper bounds and region-specific guarantees, and validate their theory with simulations that reveal the interplay between learning rate, flatness, and representation learning. Overall, the work clarifies when flatness helps generalization in logistic settings and highlights the significance of data-region considerations and training dynamics.

Abstract

We consider the problem of generalization of arbitrarily overparameterized two-layer ReLU Neural Networks with univariate input. Recent work showed that under square loss, flat solutions (motivated by flat / stable minima and Edge of Stability phenomenon) provably cannot overfit, but it remains unclear whether the same phenomenon holds for logistic loss. This is a puzzling open problem because existing work on logistic loss shows that gradient descent with increasing step size converges to interpolating solutions (at infinity, for the margin-separable cases). In this paper, we prove that the \emph{flatness implied generalization} is more delicate under logistic loss. On the positive side, we show that flat solutions enjoy near-optimal generalization bounds within a region between the left-most and right-most \emph{uncertain} sets determined by each candidate solution. On the negative side, we show that there exist arbitrarily flat yet overfitting solutions at infinity that are (falsely) certain everywhere, thus certifying that flatness alone is insufficient for generalization in general. We demonstrate the effects predicted by our theory in a well-controlled simulation study.

Paper Structure

This paper contains 43 sections, 31 theorems, 151 equations, 17 figures.

Key Result

Theorem 3.1

For the example above, there exists a choice of $\theta$ such that $f_\theta(x_i)=y_i\gamma_{\max}$ for all $i\in[n]$, $\mathcal{L}(\theta)$ is twice differentiable w.r.t. $\theta$ and $\lambda_{\max}\left(\nabla_\theta^2 \mathcal{L}(\theta)\right)\leq O\left((n^2 \gamma_{\max}+1) e^{-\gamma_{\max}}
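The exponential factor $e^{-\gamma_{\max}}$ in this bound can be related to a standard property of the logistic loss $\ell(z)=\log(1+e^{-z})$: both its slope and its curvature are at most $e^{-z}$ at margin $z$. The following is a minimal sketch of that scalar fact only, not the paper's construction; the function names are illustrative, and the network-dependent factors of the theorem are not modeled.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic sigmoid."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logistic_loss(z: float) -> float:
    """ell(z) = log(1 + exp(-z)), evaluated stably for any margin z."""
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


def loss_slope(z: float) -> float:
    """First derivative: ell'(z) = -sigmoid(-z)."""
    return -sigmoid(-z)


def loss_curvature(z: float) -> float:
    """Second derivative: ell''(z) = sigmoid(z) * sigmoid(-z)."""
    return sigmoid(z) * sigmoid(-z)


# Illustrative margins (hypothetical values, not taken from the paper).
for gamma in (1.0, 5.0, 10.0, 20.0):
    print(
        f"gamma={gamma:5.1f}  loss={logistic_loss(gamma):.3e}  "
        f"|slope|={abs(loss_slope(gamma)):.3e}  "
        f"curvature={loss_curvature(gamma):.3e}  "
        f"exp(-gamma)={math.exp(-gamma):.3e}"
    )
```

As the margin grows, the loss, its slope, and its curvature all shrink at the rate $e^{-\gamma}$, which gives some intuition for why an interpolant that is confidently certain on every training point can look arbitrarily flat to the training loss.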

Figures (17)

  • Figure 1: The left panel summarizes our findings about flatness, generalization and interpolation in logistic regression. The middle panel compares the functions learned by GD with large and small learning rates; the stable solution found with the large learning rate is simpler and smoother. The right panel illustrates the "uncertain region" (red part) of the function and the weight function (the $h$ function in \ref{equ:bias} and \ref{equ:tvbmain}) supported in the interior of the uncertain regions. Briefly, a larger weight function imposes a stronger smoothness guarantee in the corresponding region. Here we plot the asymptotic weight function in Theorem \ref{thm:tvb} with $\gamma=1.5,\zeta=0.3$ and $\mathcal{P}_x=\text{Unif}([-2,2])$. The weight function in Theorem \ref{thm:bias} can be derived identically by replacing $f_0$ with $f_\theta$.
  • Figure 2: Highlight of our empirical results. The left panel illustrates the learned functions for GD with large (0.8 or 1) and small (0.01) learning rates using different numbers of samples. The middle panel plots the impact of varying learning rate on the complexity and performance of the learned function. The right panel showcases the following relationships: (1) TV$^{(1)}$ norm vs number of data and (2) excess risk vs number of data under several fixed choices of learning rate.
  • Figure 3: Illustration of the solutions to which gradient descent with learning rate $\eta$ converges ($n=80$: Part I). As $\eta$ decreases, the fitted function goes from simple to complex. Any line below the $\mathcal{L}(f_0)$ line satisfies the "optimized" assumption from Theorem \ref{thm:erbound} and Lemma \ref{lem:erbound}. Test loss denotes $\bar{\mathcal{L}}(f)$.
  • Figure 4: Illustrations ($n=80$: Part II). As $\eta$ decreases further, the fitted function starts to overfit.
  • Figure 5: Illustration of the solutions to which GD with learning rate $\eta$ converges ($n=160$: Part I).
  • ...and 12 more figures

Theorems & Definitions (58)

  • Theorem 3.1
  • Theorem 3.2
  • Corollary 3.3: Corollary of Theorem \ref{thm:erbound}
  • Theorem 3.5
  • Theorem 3.6
  • Theorem 3.7
  • Definition B.1: Linear stability
  • Lemma B.2
  • Proof: Proof of Lemma \ref{lem:stable}
  • Lemma B.3
  • ...and 48 more