Topological Generalization Bounds for Discrete-Time Stochastic Optimization Algorithms

Rayna Andreeva, Benjamin Dupuis, Rik Sarkar, Tolga Birdal, Umut Şimşekli

TL;DR

The paper develops a discrete-time, topology-based framework to bound the generalization error of stochastic optimization algorithms by introducing alpha-weighted lifetime sums ($\boldsymbol{E}_\alpha$) and positive magnitude ($\mathrm{\mathbf{PMag}}$). These topological complexities, together with a total mutual information term $\mathrm{I}_\infty$, yield non-asymptotic generalization bounds that apply to practical, finite training trajectories. The authors provide computable, data-driven procedures to estimate these complexities and demonstrate strong correlations with generalization across vision transformers and graph neural networks, often outperforming prior persistent-homology bounds. The work offers a scalable, domain-agnostic toolkit for assessing generalization risk in discrete-time DNN training and sets the stage for broader applications and refinements in topology-guided learning theory.
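As a rough illustration of the first quantity, the sketch below computes an $\alpha$-weighted lifetime sum for a finite set of points. It assumes the standard degree-0 persistent-homology reading, in which the finite lifetimes of a Vietoris-Rips filtration equal the edge lengths of a minimum spanning tree; this summary does not spell out the paper's exact definition or normalization, so treat it as an assumption. The function names (`euclidean`, `mst_edge_lengths`, `e_alpha`) and the toy trajectory are hypothetical, and the Euclidean distance stands in for whatever (pseudo)metric is used in practice.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length coordinate tuples."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def mst_edge_lengths(points, dist=euclidean):
    """Return the edge lengths of a minimum spanning tree (Prim's algorithm)."""
    n = len(points)
    if n < 2:
        return []
    # best[i] is the cheapest known link from point i to the current tree.
    best = [dist(points[0], p) for p in points]
    in_tree = [False] * n
    in_tree[0] = True
    lengths = []
    for _ in range(n - 1):
        j = min((i for i in range(n) if not in_tree[i]), key=lambda i: best[i])
        lengths.append(best[j])
        in_tree[j] = True
        for i in range(n):
            if not in_tree[i]:
                best[i] = min(best[i], dist(points[j], points[i]))
    return lengths


def e_alpha(points, alpha, dist=euclidean):
    """Sum of MST edge lengths raised to the power alpha (assumed definition)."""
    return sum(length ** alpha for length in mst_edge_lengths(points, dist))


# Toy "trajectory" of three weight vectors in R^2 (illustrative only).
trajectory = [(0.0, 0.0), (3.0, 4.0), (3.0, 5.0)]
print(e_alpha(trajectory, 1.0))  # 6.0: MST edges of length 5 and 1
```

Prim's algorithm on the full distance matrix costs $O(m^2)$ distance evaluations for $m$ trajectory points, which is usually acceptable for the few thousand iterates typically retained from a training run. Passing a different `dist` callable swaps in a data-dependent pseudometric.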

Abstract

We present a novel set of rigorous and computationally efficient topology-based complexity notions that exhibit a strong correlation with the generalization gap in modern deep neural networks (DNNs). DNNs show remarkable generalization properties, yet the source of these capabilities remains elusive, defying the established statistical learning theory. Recent studies have revealed that properties of training trajectories can be indicative of generalization. Building on this insight, state-of-the-art methods have leveraged the topology of these trajectories, particularly their fractal dimension, to quantify generalization. Most existing works compute this quantity by assuming continuous- or infinite-time training dynamics, complicating the development of practical estimators capable of accurately predicting generalization without access to test data. In this paper, we respect the discrete-time nature of training trajectories and investigate the underlying topological quantities that can be amenable to topological data analysis tools. This leads to a new family of reliable topological complexity measures that provably bound the generalization error, eliminating the need for restrictive geometric assumptions. These measures are computationally friendly, enabling us to propose simple yet effective algorithms for computing generalization indices. Moreover, our flexible framework can be extended to different domains, tasks, and architectures. Our experimental results demonstrate that our new complexity measures correlate highly with generalization error in industry-standard architectures such as transformers and deep graph networks. Our approach consistently outperforms existing topological bounds across a wide range of datasets, models, and optimizers, highlighting the practical relevance and effectiveness of our complexity measures.
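To make the "computationally friendly" claim concrete, here is a minimal sketch of magnitude for a finite point set, following the standard construction: build the similarity matrix $Z_{ij} = e^{-s\, d(x_i, x_j)}$, solve $Zw = \mathbf{1}$ for the weighting $w$, and sum the weights. The positive-magnitude variant shown here, which sums only the positive parts of the weights, is an assumption about how $\mathrm{\mathbf{PMag}}$ is defined, since this summary does not give the formula. All function names are hypothetical, the Euclidean distance is a stand-in for the metric of interest, and the dense pure-Python solver is only meant for small inputs.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length coordinate tuples."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def solve_linear(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        tail = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - tail) / m[r][r]
    return x


def weighting(points, s, dist=euclidean):
    """Weights w solving Z w = 1, with Z_ij = exp(-s * d(x_i, x_j))."""
    z = [[math.exp(-s * dist(p, q)) for q in points] for p in points]
    return solve_linear(z, [1.0] * len(points))


def magnitude(points, s, dist=euclidean):
    """Magnitude of the point set rescaled by s: the sum of all weights."""
    return sum(weighting(points, s, dist))


def positive_magnitude(points, s, dist=euclidean):
    """Sum of the positive parts of the weights (assumed PMag definition)."""
    return sum(max(w, 0.0) for w in weighting(points, s, dist))
```

For example, `positive_magnitude(points, math.sqrt(n))` would correspond to the scale $s = \sqrt{n}$ that appears in the figure captions, with `n` taken to be the number of training samples (an assumption). For distinct points under the Euclidean metric the matrix $Z$ is positive definite, so the solve is well posed; duplicate points or a general pseudometric can make $Z$ singular and would need deduplication or regularization first.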

Paper Structure (61 sections, 17 theorems, 90 equations, 35 figures, 7 tables)

This paper contains 61 sections, 17 theorems, 90 equations, 35 figures, 7 tables.

Key Result

Theorem 3.4

Let $\rho$ be a pseudometric on $\mathbb{R}^d$. Suppose that the boundedness assumption holds and that $\ell$ is $(q, L, \rho)$-Lipschitz, for $q \geq 1$. Then, for all $\alpha \in [0, 1]$, with probability at least $1 - \zeta$, the theorem's generalization bound holds, with $K_{n,\alpha} := 2 \left( 2 L \sqrt{n} / B \right)^\alpha$.
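The constant $K_{n,\alpha}$ is the only fully explicit formula in the statement above, and a small helper shows how it scales. The sketch assumes $n$ is the number of training samples, $L$ the Lipschitz constant of the loss, and $B$ the bound on the loss from the boundedness assumption; the statement as reproduced here does not define $n$ or $B$, so these readings are assumptions, and the function name `k_constant` is hypothetical.

```python
import math


def k_constant(n, alpha, lipschitz, bound):
    """K_{n,alpha} = 2 * (2 * L * sqrt(n) / B) ** alpha, for alpha in [0, 1]."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if n <= 0 or bound <= 0:
        raise ValueError("n and bound must be positive")
    return 2.0 * (2.0 * lipschitz * math.sqrt(n) / bound) ** alpha


# alpha = 0 removes the dependence on n, L and B entirely.
print(k_constant(n=100, alpha=0.0, lipschitz=1.0, bound=2.0))  # 2.0
# alpha = 1 gives the full sqrt(n) growth.
print(k_constant(n=100, alpha=1.0, lipschitz=1.0, bound=2.0))  # 20.0
```

The two extremes show the trade-off in the choice of $\alpha$: at $\alpha = 0$ the constant is $2$ regardless of sample size, while at $\alpha = 1$ it grows like $\sqrt{n}$.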

Figures (35)

  • Figure 1: We devise a novel class of complexity measures that capture the topological properties of discrete training trajectories. These generalization bounds correlate highly with the test performance for a variety of deep networks, data domains, and datasets. The figure shows different trajectories (a) embedded using multi-dimensional scaling based on the distance matrices (b), computed using either the Euclidean distance ($\|\cdot\|_2$) between weights as in [birdal2021intrinsic] or the loss-induced pseudometric ($\rho_S$) as in [dupuis2023generalization]. Panel (c) plots the average granulated Kendall coefficients for two of our generalization measures ($\boldsymbol{E}_{\alpha}$ and $\mathrm{\mathbf{PMag}}(\sqrt{n})$) in comparison to the state-of-the-art persistent homology dimensions [birdal2021intrinsic, dupuis2023generalization] for a range of models, datasets, and domains, revealing significant gains and practical relevance.
  • Figure 1: Correlation coefficients associated with the different topological complexities.
  • Figure 2: Left: Comparison of $\mathrm{\mathbf{Mag}}$ and $\mathrm{\mathbf{PMag}}$ (for $s = \sqrt{n}$), for different (pseudo)metrics (ViT on CIFAR$10$). Right: Relative variation of the quantities $\boldsymbol{E}_\alpha(\mathcal{W}_{t_0 \to T})$ and $\mathrm{\mathbf{Mag}}(\sqrt{n}\mathcal{W}_{t_0 \to T})$ with respect to the proportion of the data used to estimate $\rho_S^{(1)}$ (ViT on CIFAR$10$).
  • Figure 3: $\rho_S$-based complexity measures vs. generalization gap for a ViT trained on CIFAR$10$: $\dim_{\mathrm{PH}}$ (left), $\mathrm{\mathbf{PMag}}(\sqrt{n})$ (middle), and $\boldsymbol{E}_1$ (right).
  • Figure 4: Granulated Kendall coefficients for several models, datasets and topological quantities. Note that our framework is directly applicable to graph networks.
  • ...and 30 more figures

Theorems & Definitions (63)

  • Definition 3.1: $(q,L,\rho)$-Lipschitz continuity
  • Example 3.2: Data-dependent pseudometrics
  • Example 3.3: Euclidean distance
  • Theorem 3.4
  • Theorem 3.5
  • proof
  • Remark 3.6
  • Definition A.1: Total mutual information
  • Definition A.2: Rademacher complexity on a hypothesis set
  • Definition A.3
  • ...and 53 more