CHAIN: Enhancing Generalization in Data-Efficient GANs via lipsCHitz continuity constrAIned Normalization

Yao Ni, Piotr Koniusz

TL;DR

Data-scarce GANs suffer from discriminator overfitting and training instability. CHAIN reimagines Batch Normalization by replacing the centering step with zero-mean regularization and enforcing a Lipschitz constraint on scaling through Adaptive Root Mean Square normalization, paired with adaptive interpolation between normalized and unnormalized features. The approach is underpinned by a PAC-Bayesian, IPM-based generalization analysis that connects reduced gradient norms to improved generalization, and is validated by extensive experiments showing state-of-the-art results on CIFAR-10/100, ImageNet, and several low-shot and high-resolution few-shot datasets. CHAIN proves to be a simple, architecture-agnostic technique that stabilizes GAN training under limited data while delivering substantial performance gains, accompanied by public code.
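The TL;DR names three ingredients: a zero-mean penalty in place of BN centering, RMS-based scaling under a Lipschitz constraint (ARMS), and adaptive mixing of normalized and unnormalized features. Below is a minimal NumPy sketch of how these pieces could fit together. It is not the authors' implementation (that lives in the linked repository). The exact penalty form, the min-RMS rescaling used here as the Lipschitz cap, the element-wise Bernoulli mask, and the ADA-style rule for adapting the mixing probability `p` are all our assumptions. Every function name is hypothetical.

```python
import numpy as np

# Hypothetical sketch of a CHAIN_batch-style layer for features of shape (N, C, H, W).
# ASSUMPTIONS (not taken from the official code): the zero-mean penalty form, the
# min-RMS rescaling used as the Lipschitz cap, the element-wise Bernoulli mask,
# and the ADA-style update rule for the mixing probability p.


def zero_mean_penalty(x: np.ndarray) -> float:
    """Replaces BN centering: penalize the squared per-channel batch mean
    (meant to be added to the discriminator loss) instead of subtracting it."""
    mu = x.mean(axis=(0, 2, 3))  # per-channel mean, shape (C,)
    return float(np.sum(mu ** 2))


def arms_scale(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """RMS scaling with an assumed Lipschitz cap: each channel is divided by its
    RMS and multiplied by the smallest channel RMS, so no channel is ever
    scaled by a factor larger than 1."""
    psi = np.sqrt((x ** 2).mean(axis=(0, 2, 3), keepdims=True) + eps)  # (1, C, 1, 1)
    return x * (psi.min() / psi)


def chain_batch(x: np.ndarray, p: float, rng: np.random.Generator,
                eps: float = 1e-5) -> np.ndarray:
    """Interpolate normalized and unnormalized features with a random binary
    mask M ~ Bernoulli(p): output = M * ARMS(x) + (1 - M) * x."""
    x_hat = arms_scale(x, eps)
    m = (rng.random(x.shape) < p).astype(x.dtype)  # element-wise mask (assumption)
    return m * x_hat + (1.0 - m) * x


def update_p(p: float, d_real: np.ndarray, target: float = 0.6,
             step: float = 0.01) -> float:
    """ADA-style overfitting heuristic (hypothetical here): raise p when the
    mean sign of the discriminator's outputs on real images exceeds `target`,
    lower it otherwise, and keep p inside [0, 1]."""
    r = float(np.mean(np.sign(d_real)))
    p = p + step if r > target else p - step
    return float(np.clip(p, 0.0, 1.0))


# Usage on random features.
rng = np.random.default_rng(0)
feats = rng.normal(size=(16, 8, 4, 4))
out = chain_batch(feats, p=0.5, rng=rng)
reg = zero_mean_penalty(feats)
```

In a training loop, the penalty would be added to the discriminator loss with some weight, and `update_p` would be called periodically on the discriminator's outputs for real images. Both the weight and the schedule are placeholders, not values from the paper.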

Abstract

Generative Adversarial Networks (GANs) have significantly advanced image generation, but their performance depends heavily on abundant training data. In scenarios with limited data, GANs often struggle with discriminator overfitting and unstable training. Batch Normalization (BN), despite being known for enhancing generalization and training stability, has rarely been used in the discriminator of Data-Efficient GANs. Our work addresses this gap by identifying a critical flaw in BN: the tendency for gradient explosion during the centering and scaling steps. To tackle this issue, we present CHAIN (lipsCHitz continuity constrAIned Normalization), which replaces the conventional centering step with zero-mean regularization and integrates a Lipschitz continuity constraint in the scaling step. CHAIN further enhances GAN training by adaptively interpolating the normalized and unnormalized features, effectively avoiding discriminator overfitting. Our theoretical analyses firmly establish CHAIN's effectiveness in reducing gradients in latent features and weights, improving stability and generalization in GAN training. Empirical evidence supports our theory. CHAIN achieves state-of-the-art results in data-limited scenarios on CIFAR-10/100, ImageNet, and five low-shot and seven high-resolution few-shot image datasets. Code: https://github.com/MaxwellYaoNi/CHAIN
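The abstract attributes BN's instability partly to gradient explosion in the scaling step. A generic way to see this (our own illustration, not the paper's derivation, which also covers centering): for $y = x/\mathrm{rms}(x)$ the Jacobian is $\frac{1}{\mathrm{rms}(x)}\big(I - \frac{xx^\top}{\lVert x\rVert^2}\big)$, whose spectral norm is $1/\mathrm{rms}(x)$, so the smaller the features, the more the step amplifies gradients. The NumPy sketch below checks this numerically with central finite differences; all function names are hypothetical.

```python
import numpy as np

# Generic illustration: the RMS scaling step y = x / rms(x) has a Jacobian whose
# spectral norm equals 1 / rms(x), so gradients flowing through it grow as the
# feature magnitude shrinks.


def rms(x: np.ndarray) -> float:
    return float(np.sqrt(np.mean(x ** 2)))


def rms_normalize(x: np.ndarray) -> np.ndarray:
    """Scaling step without centering: divide the feature vector by its RMS."""
    return x / rms(x)


def jacobian_fd(f, x: np.ndarray, h: float = 1e-6) -> np.ndarray:
    """Central finite-difference Jacobian of f at a 1-D vector x."""
    n = x.size
    jac = np.zeros((n, n))
    for i in range(n):
        e = np.zeros(n)
        e[i] = h
        jac[:, i] = (f(x + e) - f(x - e)) / (2.0 * h)
    return jac


def scaling_gain(x: np.ndarray) -> float:
    """Largest factor by which the scaling step can amplify a small perturbation of x."""
    return float(np.linalg.norm(jacobian_fd(rms_normalize, x), ord=2))


x = np.array([1.0, -2.0, 3.0, 0.5])
for scale in (1.0, 0.1, 0.01):
    xs = scale * x
    print(f"scale={scale:<5} gain={scaling_gain(xs):9.3f}  1/rms={1.0 / rms(xs):9.3f}")
```

Shrinking the input by 100x raises the gain by 100x, which is the kind of amplification a Lipschitz cap on the scaling factor is meant to rule out.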

Paper Structure (31 sections, 7 theorems, 61 equations, 18 figures, 8 tables, 1 algorithm)

Key Result

Lemma 3.1

(Partial result of Theorem 1 in ji2021understanding.) Assume the discriminator set $\mathcal{H}$ is even, i.e., $h\in\mathcal{H}$ implies $-h\in\mathcal{H}$, and $\lVert h\rVert_\infty\leq\Delta$. Let $\hat{\mu}_n$ and $\hat{\nu}_n$ be the empirical measures of $\mu$ and $\nu_n$ with sample size $n$
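Lemma 3.1 is phrased in terms of the integral probability metric (IPM) induced by the discriminator set $\mathcal{H}$. For reference, the standard definition is given below (our addition; the paper's own notation may differ), together with why the two assumptions matter: evenness lets the absolute value be dropped, because a negative gap for $h$ becomes a positive gap for $-h\in\mathcal{H}$, and the bound $\lVert h\rVert_\infty\leq\Delta$ caps the metric at $2\Delta$.

```latex
d_{\mathcal{H}}(\mu,\nu)
  = \sup_{h\in\mathcal{H}} \Big|\, \mathbb{E}_{x\sim\mu}[h(x)] - \mathbb{E}_{x\sim\nu}[h(x)] \,\Big|
  = \sup_{h\in\mathcal{H}} \Big( \mathbb{E}_{x\sim\mu}[h(x)] - \mathbb{E}_{x\sim\nu}[h(x)] \Big)
  \quad \text{if } \mathcal{H} \text{ is even},
\qquad
0 \le d_{\mathcal{H}}(\mu,\nu) \le 2\Delta .
```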

Figures (18)

  • Figure 1: Motivation for using BN, the discriminator with CHAIN, the modules in CHAIN, and PyTorch-style pseudo-code for CHAIN$_\text{batch}$.
  • Figure 2: (a) Mean cosine similarity of discriminator pre-activation features and (b) gradient norm of the feature extractor w.r.t. the input, evaluated for OmniGAN, OmniGAN+0C (using the centering step in Eq. (eq:centering)), and OmniGAN+A0C (adaptive interpolation between centered and uncentered features). Evaluated on 10% of the CIFAR-10 data with OmniGAN ($d=256$).
  • Figure 3: (a) Gradient norm of the discriminator output w.r.t. the input during training and (b) effective rank roy2007effective of the pre-activation features in the discriminator, evaluated on 10% of the CIFAR-10 data with OmniGAN ($d=256$). CHAIN$_{+0C}$: CHAIN with the centering step. CHAIN$_{-LC}$: CHAIN without the Lipschitz constraint.
  • Figure 4: The discriminator output w.r.t. real, fake and test images using (a) OmniGAN, (b) OmniGAN+CHAIN, and (c) the gradient norm of the discriminator output w.r.t. discriminator weights on 10% CIFAR-10 using OmniGAN ($d=256$). Note the $y$-axis in (b) is scaled for clearer visualization.
  • Figure 5: The discriminator output w.r.t. real, fake and test images of (a) BigGAN, (b) BigGAN+CHAIN, along with (c) the gradient norm of the discriminator output w.r.t. discriminator weights on 10% CIFAR-100 with BigGAN ($d=256$).
  • ...and 13 more figures
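Figures 2 and 3 track the mean cosine similarity and the effective rank of discriminator features. A minimal NumPy sketch of both diagnostics follows, assuming features are flattened into an (N, D) array with one row per image. The effective rank uses the Roy and Vetterli definition (exponential of the entropy of the normalized singular values); averaging cosine similarity over all distinct sample pairs is our assumption about the protocol. Function names are hypothetical.

```python
import numpy as np

# Hypothetical helpers for the feature diagnostics in Figures 2 and 3.
# `feats` is assumed to be an (N, D) array: one flattened feature vector per image.


def effective_rank(feats: np.ndarray, eps: float = 1e-12) -> float:
    """Effective rank (Roy & Vetterli, 2007): exp of the Shannon entropy of the
    singular values normalized to sum to one."""
    s = np.linalg.svd(feats, compute_uv=False)
    p = s / (s.sum() + eps)
    p = p[p > 0]  # treat 0 * log 0 as 0
    return float(np.exp(-np.sum(p * np.log(p))))


def mean_cosine_similarity(feats: np.ndarray, eps: float = 1e-12) -> float:
    """Average cosine similarity over all pairs of distinct samples (assumed protocol)."""
    z = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + eps)
    sim = z @ z.T
    off_diag = ~np.eye(sim.shape[0], dtype=bool)  # exclude self-similarity
    return float(sim[off_diag].mean())


rng = np.random.default_rng(0)
diverse = rng.normal(size=(64, 32))  # spread-out features
# Collapsed features: every row is a positive multiple of the same direction.
collapsed = np.outer(np.abs(rng.normal(size=64)) + 0.1, rng.normal(size=32))
print(effective_rank(diverse), effective_rank(collapsed))
print(mean_cosine_similarity(diverse), mean_cosine_similarity(collapsed))
```

Collapsed features give an effective rank near 1 and a mean cosine similarity near 1, while spread-out features give a much higher rank and a similarity near 0.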

Theorems & Definitions (7)

  • Lemma 3.1
  • Proposition 3.1
  • Theorem 3.1
  • Theorem 3.2
  • Theorem 3.3
  • Lemma B.1
  • Theorem C.1