Positive Distribution Shift as a Framework for Understanding Tractable Learning

Marko Medvedev, Idan Attias, Elisabetta Cornacchia, Theodor Misiakiewicz, Gal Vardi, Nathan Srebro

TL;DR

This work reframes distribution shift not as a hindrance but as a lever for tractable learning by proposing Positive Distribution Shift (PDS) and DS-PAC/f-PDS frameworks. It demonstrates that carefully chosen training distributions can render computationally hard classes, such as parity and junta functions, efficiently learnable with standard gradient-based methods, while preserving the target test distribution. The paper connects PDS to membership-query models, showing that DS-PAC implies NA-MQ and, in turn, R-DS-PAC, thereby linking practical training-data strategies to classical query-based frameworks. Collectively, these results provide a theoretical foundation for dataset design as a core component of learnability, with implications for when and how SGD-based training on neural nets can succeed under covariate shift.

Abstract

We study a setting where the goal is to learn a target function f(x) with respect to a target distribution D(x), but training is done on i.i.d. samples from a different training distribution D'(x), labeled by the true target f(x). Such a distribution shift (here in the form of covariate shift) is usually viewed negatively, as hurting or making learning harder, and the traditional distribution shift literature is mostly concerned with limiting or avoiding this negative effect. In contrast, we argue that with a well-chosen D'(x), the shift can be positive and make learning easier -- a perspective called Positive Distribution Shift (PDS). Such a perspective is central to contemporary machine learning, where much of the innovation is in finding good training distributions D'(x), rather than changing the training algorithm. We further argue that the benefit is often computational rather than statistical, and that PDS allows computationally hard problems to become tractable even using standard gradient-based training. We formalize different variants of PDS, show how certain hard classes are easily learnable under PDS, and make connections with membership query learning.
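In symbols, the setup described above can be written roughly as follows. This is only a notational sketch: the hypothesis $h$, the sample size $m$, and the error notation $\mathrm{err}_{\mathcal{D}}$ are assumptions introduced here for illustration and may differ from the paper's own notation.

```latex
% Training sample: inputs drawn from the training distribution D', labeled by the target f
S = \{(x_i, f(x_i))\}_{i=1}^{m}, \qquad x_i \overset{\text{i.i.d.}}{\sim} \mathcal{D}'

% Evaluation: error of the learned hypothesis h on the target distribution D
\mathrm{err}_{\mathcal{D}}(h) = \Pr_{x \sim \mathcal{D}}\bigl[h(x) \neq f(x)\bigr]
```

The standard (no-shift) setting corresponds to $\mathcal{D}' = \mathcal{D}$; PDS asks whether some other choice of $\mathcal{D}'$ makes a small $\mathrm{err}_{\mathcal{D}}(h)$ easier to reach.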

Paper Structure (48 sections, 19 theorems, 77 equations, 4 figures, 1 table, 3 algorithms)

Key Result

Theorem 3.2

For any $d,s \in \mathbb{N}$ there exists a fixed non-standard feed-forward neural network (footnote: the network is non-standard in either its topology (i.e. not fully connected) and activation with standard initialization or it is fully connected with standard activation function but with special initializati…

Figures (4)

  • Figure 1: Noisy Parities. We study learning a degree-$25$ parity over $50$ bits with label noise $\eta \in \{0,0.02,0.05\}$, using a two-layer $\mathrm{ReLU}$ network with $1024$ hidden units trained by SGD (batch size $64$, fresh samples, square loss, learning rate $0.01$, both layers trained jointly). In the PDS setting, training samples are drawn from ${\mathcal{D}}'=\tfrac{1}{2}{\mathcal{D}}_{0.96}+\tfrac{1}{2}{\mathcal{D}}_0$, where ${\mathcal{D}}_\mu={\rm Rad}((\mu+1)/2)^{\otimes d}$ for $\mu\in[-1,1]$; in the standard setting, from ${\mathcal{D}}'={\mathcal{D}}_0$. The left panel shows test error on ${\mathcal{D}}_0$ during training, where PDS yields markedly more efficient learning; dotted lines indicate Bayes error ($\eta$). The right panel compares ${\mathcal{D}}'={\mathcal{D}}_0$, ${\mathcal{D}}'=\tfrac{1}{2}{\mathcal{D}}_{0.96}+\tfrac{1}{2}{\mathcal{D}}_0$, and ${\mathcal{D}}'={\mathcal{D}}_{0.96}$, plotting test error on ${\mathcal{D}}_{0.96}$ (blue) and ${\mathcal{D}}_0$ (orange) after $10^6$ steps. Only the mixture distribution achieves PDS generalization to the target ${\mathcal{D}}_0$.
  • Figure 2: Noisy juntas. (Left) We consider learning $f_9$ (see Sec. \ref{eq:junta_def_exp}) over $d=50$ bits on ${\mathcal{D}} ={\rm Unif}\{ \pm 1 \}^d$ with a two-layer $\mathrm{ReLU}$ network ($1024$ hidden units) trained with SGD (batch size $64$, fresh samples, square loss, learning rate $0.01$, both layers trained jointly). In the PDS setting we train on ${\mathcal{D}}' = \frac{1}{2} {\mathcal{D}}_{{\boldsymbol \mu}} + \frac{1}{2} {\mathcal{D}}_{\boldsymbol{0}}$, with ${\boldsymbol \mu} \sim {\rm Unif}[-1,1]^{\otimes d}$ (and where ${\mathcal{D}}_{\boldsymbol \mu} := \otimes_{i \in [d]} {\rm Rad}((\mu_i+1)/2)$), while in the standard (no PDS) setting ${\mathcal{D}}'={\mathcal{D}}$. We plot the test error on ${\mathcal{D}}$ during training; the dotted lines show the Bayes error (i.e. $\eta$). (Right) We consider learning $f_7$ on ${\mathcal{D}}={\rm Unif}\{ \pm 1 \}^d$ and we plot the sample complexity needed to reach within $0.01$ of Bayes error versus the input dimension. In both plots, PDS training is markedly more efficient.
  • Figure 3: Sparse parity with noise. We compare PDS and standard (no-PDS) learning for a $5$-parity with label noise $\eta \in \{0,0.02,0.05\}$. (Left) For $d=50$, we plot the test accuracy on ${\mathcal{D}}$ versus gradient descent steps for a 4-layer ReLU network trained with SGD (batch size $b=64$) on fresh samples from ${\mathcal{D}}'$ (PDS) and from ${\mathcal{D}}$ (no PDS). Dotted lines show Bayes accuracy (i.e., $1-\eta$). (Right) We plot the sample complexity to reach within $0.01$ of Bayes error versus input dimension. We report the simulations that converged within $10^6$ training steps. In both figures, we see that PDS training is markedly more efficient.
  • Figure 4: Sparse juntas with noise. We compare PDS and standard (no-PDS) learning for $f_9$ (see \ref{eq:junta_def_exp}) with label noise $\eta \in \{0,0.02,0.05\}$. (Left) For $d=50$, we plot the test accuracy on ${\mathcal{D}}$ versus gradient descent steps for a 4-layer ReLU network trained with SGD (batch size $b=64$) on fresh samples from ${\mathcal{D}}'$ (PDS) and from ${\mathcal{D}}$ (no PDS). Dotted lines show Bayes accuracy (i.e., $1-\eta$). (Right) For $f_7$, we plot the sample complexity to reach within $0.01$ of Bayes accuracy versus input dimension. In both figures, we see that PDS training is markedly more efficient.
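
The training distributions named in the captions above can be sampled with a few lines of NumPy. The following is a minimal sketch, not the authors' code: it assumes that ${\rm Rad}(p)$ means "$+1$ with probability $p$, $-1$ otherwise" (so each coordinate of ${\mathcal{D}}_\mu$ has mean $\mu$) and that label noise flips each label independently with probability $\eta$. The function names and the choice of parity support are hypothetical.

```python
import numpy as np

def sample_product(mu, n, rng):
    """Draw n points from D_mu = prod_i Rad((mu_i + 1) / 2) on {-1, +1}^d.
    Assumes Rad(p) puts probability p on +1, so E[x_i] = mu_i."""
    mu = np.asarray(mu, dtype=float)
    p_plus = (mu + 1.0) / 2.0
    return np.where(rng.random((n, mu.size)) < p_plus, 1, -1)

def sample_pds_mixture(mu, n, rng):
    """Draw n points from the mixture D' = 1/2 D_mu + 1/2 D_0."""
    mu = np.asarray(mu, dtype=float)
    biased = sample_product(mu, n, rng)
    uniform = sample_product(np.zeros_like(mu), n, rng)
    use_biased = rng.random(n) < 0.5  # fair coin per sample picks the component
    return np.where(use_biased[:, None], biased, uniform)

def noisy_parity_labels(x, support, eta, rng):
    """Parity over the coordinates in `support`; each label is flipped
    independently with probability eta (assumed noise model)."""
    clean = np.prod(x[:, support], axis=1)
    flips = np.where(rng.random(x.shape[0]) < eta, -1, 1)
    return clean * flips

rng = np.random.default_rng(0)
d, k, eta = 50, 25, 0.02          # sizes taken from the Figure 1 caption
support = np.arange(k)            # hypothetical support: the first k coordinates
x_train = sample_pds_mixture(np.full(d, 0.96), 64, rng)  # one batch of 64
y_train = noisy_parity_labels(x_train, support, eta, rng)
```

For the junta experiments of Figure 2, the same sampler could be called with per-coordinate biases drawn once as `mu = rng.uniform(-1, 1, size=d)`. Evaluation would still use samples from the target distribution, i.e. `sample_product(np.zeros(d), n, rng)`.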

Theorems & Definitions (37)

  • Definition 1.1: PDS Learning
  • Definition 3.1: f-Dependent Positive Distribution Shift (f-PDS)
  • Theorem 3.2: Any Poly-sized Circuit with Label Noise is f-PDS Learnable with a Non-Standard Network
  • Definition 4.1: Deterministic Distribution-Shift PAC (D-DS-PAC)
  • Definition 4.2: Randomized Distribution Shift PAC (R-DS-PAC)
  • Theorem 4.3: Noisy Parities are Tractably D-DS-PAC Learnable
  • Theorem 4.5: Noisy Parities are D-DS-PAC Learnable with Stylized Analyzable GD
  • Theorem 4.6: Noisy Juntas are D-DS-PAC Learnable
  • Theorem 4.8: Noisy Juntas are R-DS-PAC Learnable with Stylized Analyzable GD
  • Definition 5.1: Non-Adaptive Membership Queries (NA-MQ)
  • ...and 27 more