Flat Minima and Generalization: Insights from Stochastic Convex Optimization

Matan Schliserman, Shira Vansover-Hager, Tomer Koren

TL;DR

This work studies the link between flat minima and generalization in the canonical setting of stochastic convex optimization with a non-negative, $\beta$-smooth objective. It finds that, even in this fundamental and well-studied setting, flat empirical minima may incur trivial $\Omega(1)$ population risk while sharp minima generalize optimally.

Abstract

Understanding the generalization behavior of learning algorithms is a central goal of learning theory. A recently emerging explanation is that learning algorithms are successful in practice because they converge to flat minima, which have been consistently associated with improved generalization performance. In this work, we study the link between flat minima and generalization in the canonical setting of stochastic convex optimization with a non-negative, $\beta$-smooth objective. Our first finding is that, even in this fundamental and well-studied setting, flat empirical minima may incur trivial $\Omega(1)$ population risk while sharp minima generalize optimally. We then show that this poor generalization behavior extends to two natural "sharpness-aware" algorithms originally proposed by Foret et al. (2021) to bias optimization toward flat solutions: Sharpness-Aware Gradient Descent (SA-GD) and Sharpness-Aware Minimization (SAM). For SA-GD, which performs gradient steps on the maximal loss in a predefined neighborhood, we prove that while it converges to a flat minimum at a fast rate, the population risk of the solution can still be as large as $\Omega(1)$, indicating that even flat minima found algorithmically by a sharpness-aware gradient method might generalize poorly. For SAM, a computationally efficient approximation of SA-GD based on normalized ascent steps, we show that although it minimizes the empirical loss, it may converge to a sharp minimum and also incur population risk $\Omega(1)$. Finally, we establish population risk upper bounds for both SA-GD and SAM using algorithmic stability techniques.
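To make the two update rules described above concrete, here is a minimal one-dimensional sketch. It assumes a convex, differentiable loss `f` with derivative `df`; the helper names `sa_gd_step` and `sam_step` and the toy quadratic are hypothetical illustrations, not code from the paper. In one dimension the maximum of a convex function over $[w-\rho, w+\rho]$ is attained at an endpoint, so the SA-GD inner maximization is exact here, while SAM perturbs along the normalized gradient, which in one dimension is $\rho\,\mathrm{sign}(f'(w))$.

```python
import math
from typing import Callable

# Hypothetical 1-D helpers for illustration only; not the paper's code.
Loss = Callable[[float], float]


def sa_gd_step(w: float, f: Loss, df: Loss, rho: float, eta: float) -> float:
    """One SA-GD step on h(w) = max_{|v| <= rho} f(w + v), assuming f is convex."""
    # A convex function on [w - rho, w + rho] attains its maximum at an endpoint,
    # so the inner maximization reduces to comparing two values.
    v = rho if f(w + rho) >= f(w - rho) else -rho
    # The gradient of h at w is f'(w + v*) for the maximizing v* (a subgradient on ties).
    return w - eta * df(w + v)


def sam_step(w: float, df: Loss, rho: float, eta: float) -> float:
    """One SAM step: normalized ascent perturbation, then a descent step from there."""
    g = df(w)
    if g == 0.0:
        # The normalized direction is undefined; a plain gradient step leaves w unchanged.
        return w
    v = rho * math.copysign(1.0, g)  # rho * g / |g| in one dimension
    return w - eta * df(w + v)


if __name__ == "__main__":
    f = lambda w: 0.5 * w * w  # toy non-negative, 1-smooth convex loss
    df = lambda w: w
    w_sagd = w_sam = 1.0
    for _ in range(50):
        w_sagd = sa_gd_step(w_sagd, f, df, rho=0.1, eta=0.5)
        w_sam = sam_step(w_sam, df, rho=0.1, eta=0.5)
    print(w_sagd, w_sam)
```

In higher dimensions the SAM perturbation becomes $\rho \nabla f(w)/\|\nabla f(w)\|$, whereas the SA-GD inner maximum over a ball generally has no closed form; the one-dimensional restriction is what makes the exact SA-GD step cheap in this sketch.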

Paper Structure

This paper contains 35 sections, 16 theorems, and 103 equations.

Key Result

Theorem 1

For every $n \in \mathbb{N}$ and $0 \leq \rho \leq \tfrac{1}{2}$, let $d = 2^n + 1$ and define $W = \{x \in \mathbb{R}^d : \|x\| \leq 1\}$. Then there exist an instance set $\mathcal{Z}$, a distribution $\mathcal{D}$ over $\mathcal{Z}$, and a loss function $f : W \times \mathcal{Z} \to \mathbb{R}$ …

Theorems & Definitions (32)

  • Definition 1: $\boldsymbol{\rho}$-flatness
  • Theorem 1
  • Theorem 2
  • Lemma 1
  • Theorem 3
  • Theorem 4
  • Theorem 5
  • Theorem 6
  • Theorem 7
  • Theorem 8
  • ...and 22 more