DC4L: Distribution Shift Recovery via Data-Driven Control for Deep Learning Models

Vivian Lin, Kuk Jin Jang, Souradeep Dutta, Michele Caprio, Oleg Sokolsky, Insup Lee

TL;DR

This work tackles the problem of deep model brittleness under real-world distribution shifts by proposing DC4L and its online supervisor SuperStAR, which sanitize inputs through a learned sequence of semantic-preserving transforms. The method formulates shift recovery as a Markov decision process and learns a policy via reinforcement learning, guided by a Wasserstein-distance-based reward and an operability classifier that checks whether the method can safely be applied. Dimensionality reduction through orthonormal (Cai-Lim) projections enables efficient Wasserstein distance estimation, supporting online decision-making. Applied to ImageNet-C and CIFAR-100-C, SuperStAR yields accuracy gains of up to 14.21% and 8.25%, respectively, and generalizes to composite shifts without retraining the policy. The approach offers a practical, online, augmentation-free mechanism for bolstering robustness in real-world vision systems, with room for extending the action library and addressing speed constraints.
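The distance estimate described above can be sketched as follows. This is a minimal illustration, not the authors' implementation: it assumes a random orthonormal projection (the Q factor of a Gaussian matrix) as a stand-in for the Cai-Lim projection, and computes the exact 2-Wasserstein distance between two equal-size batches as a linear assignment problem. The function names and the 50-dimensional target (borrowed from the MNIST figures) are hypothetical choices.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def random_orthonormal_projection(in_dim: int, out_dim: int, seed: int = 0) -> np.ndarray:
    """Return an (in_dim, out_dim) matrix whose columns are orthonormal.

    Assumption: a random orthonormal basis from a QR decomposition stands in
    for the Cai-Lim projection used in the paper.
    """
    assert out_dim <= in_dim, "cannot project to more dimensions than the input has"
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((in_dim, out_dim)))  # reduced QR
    return q


def empirical_wasserstein2(x: np.ndarray, y: np.ndarray) -> float:
    """Exact W2 between two equal-size point clouds with uniform weights.

    For n points on each side an optimal coupling is a permutation, so W2
    reduces to a linear assignment over squared Euclidean costs.
    """
    assert x.shape == y.shape, "this sketch assumes equal-size batches"
    cost = cdist(x, y, "sqeuclidean")
    rows, cols = linear_sum_assignment(cost)
    return float(np.sqrt(cost[rows, cols].mean()))


def projected_wasserstein2(v: np.ndarray, v_c: np.ndarray, out_dim: int = 50, seed: int = 0) -> float:
    """W2 between phi(V) and phi(V_c), using one shared orthonormal projection phi."""
    phi = random_orthonormal_projection(v.shape[1], out_dim, seed)
    return empirical_wasserstein2(v @ phi, v_c @ phi)


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    clean = rng.random((200, 576))  # stand-in for 200 flattened 24x24 images
    for sigma in (0.0, 0.25, 0.5):
        noisy = clean + sigma * rng.standard_normal(clean.shape)
        print(sigma, projected_wasserstein2(clean, noisy, out_dim=50))
```

Because the projection has orthonormal columns, projected pairwise distances never exceed the original ones, so the projected W2 is a lower bound on the full-dimensional W2. The assignment formulation is exact for equal-size uniform samples but its cost grows roughly cubically with the batch size.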

Abstract

Deep neural networks have repeatedly been shown to be non-robust to the uncertainties of the real world, even to naturally occurring ones. The vast majority of current approaches have focused on data-augmentation methods to expand the range of perturbations that the classifier is exposed to while training. A relatively unexplored avenue that is equally promising involves sanitizing an image as a preprocessing step, depending on the nature of perturbation. In this paper, we propose to use control for learned models to recover from distribution shifts online. Specifically, our method applies a sequence of semantic-preserving transformations to bring the shifted data closer in distribution to the training set, as measured by the Wasserstein distance. Our approach is to 1) formulate the problem of distribution shift recovery as a Markov decision process, which we solve using reinforcement learning, 2) identify a minimum condition on the data for our method to be applied, which we check online using a binary classifier, and 3) employ dimensionality reduction through orthonormal projection to aid in our estimates of the Wasserstein distance. We provide theoretical evidence that orthonormal projection preserves characteristics of the data at the distributional level. We apply our distribution shift recovery approach to the ImageNet-C benchmark for distribution shifts, demonstrating an improvement in average accuracy of up to 14.21% across a variety of state-of-the-art ImageNet classifiers. We further show that our method generalizes to composites of shifts from the ImageNet-C benchmark, achieving improvements in average accuracy of up to 9.81%. Finally, we test our method on CIFAR-100-C and report improvements of up to 8.25%.
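The recovery loop can be illustrated with a toy episode. In this sketch the state is the current batch, each action applies one transform from a small hypothetical library (`ACTIONS`), and the reward is the decrease in Wasserstein distance to the validation batch. A greedy one-step policy stands in for the reinforcement-learning policy the paper trains, and the operability check is omitted; the transforms, their parameters, and the function names are assumptions chosen for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def w2(x: np.ndarray, y: np.ndarray) -> float:
    """Exact W2 between equal-size batches (rows are samples), via linear assignment."""
    cost = cdist(x, y, "sqeuclidean")
    rows, cols = linear_sum_assignment(cost)
    return float(np.sqrt(cost[rows, cols].mean()))


# Hypothetical action library: simple intensity transforms on images scaled to [0, 1].
ACTIONS = {
    "contrast_up": lambda b: np.clip(0.5 + 1.2 * (b - 0.5), 0.0, 1.0),
    "contrast_down": lambda b: np.clip(0.5 + 0.8 * (b - 0.5), 0.0, 1.0),
    "brighten": lambda b: np.clip(b + 0.05, 0.0, 1.0),
    "darken": lambda b: np.clip(b - 0.05, 0.0, 1.0),
}


def greedy_recovery(v: np.ndarray, v_c: np.ndarray, max_steps: int = 10):
    """Run one episode from state v_c toward the validation batch v.

    reward(a) = W2(state, v) - W2(a(state), v). A greedy one-step policy
    stands in for the learned RL policy; the episode stops when no action
    has positive reward or after max_steps transforms.
    """
    state, dist, chosen = v_c, w2(v_c, v), []
    for _ in range(max_steps):
        next_states = {name: f(state) for name, f in ACTIONS.items()}
        next_dists = {name: w2(s, v) for name, s in next_states.items()}
        best = min(next_dists, key=next_dists.get)  # max reward == min next distance
        if dist - next_dists[best] <= 0.0:
            break  # no transform moves the batch closer: terminate the episode
        state, dist = next_states[best], next_dists[best]
        chosen.append(best)
    return chosen, state


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    v = rng.random((64, 100))        # stand-in validation batch (64 flattened images)
    v_c = 0.5 + 0.3 * (v - 0.5)      # synthetic low-contrast shift
    actions, recovered = greedy_recovery(v, v_c)
    print(actions, w2(v_c, v), w2(recovered, v))
```

A greedy policy only looks one step ahead; a learned policy can, in principle, accept a temporarily worse distance to reach a better composition later, which is one reason to frame the problem as an MDP rather than a local search.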

Paper Structure (26 sections, 3 theorems, 17 equations, 14 figures, 7 tables, 1 algorithm)

Key Result

theorem 1

$R(\mathcal{I}_k) \leq \alpha \cdot d_{TV}(D, D_{\mathcal{I}_k \circ \mathbb{T}})$, for some finite $\alpha \in \mathbb{R}$ and semantic-preserving transform $\mathcal{I}_k$.
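Bounds of this shape commonly rest on a standard fact: if a function f takes values in [0, M], then |E_P[f] - E_Q[f]| ≤ M · d_TV(P, Q), where d_TV(P, Q) = ½ Σ_x |P(x) - Q(x)| for discrete distributions. The sketch below checks this numerically on random discrete distributions. It is offered only as intuition for how a finite constant such as α can arise (for example, from a bounded loss); reading R as a risk is an assumption, and this is not the paper's proof.

```python
import numpy as np


def tv_distance(p: np.ndarray, q: np.ndarray) -> float:
    """Total variation distance between discrete distributions: 0.5 * sum |p - q|."""
    return 0.5 * float(np.abs(p - q).sum())


def worst_bound_ratio(num_outcomes: int = 20, m: float = 1.0, trials: int = 1000, seed: int = 0) -> float:
    """Largest observed |E_P f - E_Q f| / (M * d_TV(P, Q)) over random P, Q and f in [0, M]."""
    rng = np.random.default_rng(seed)
    worst = 0.0
    for _ in range(trials):
        p = rng.dirichlet(np.ones(num_outcomes))
        q = rng.dirichlet(np.ones(num_outcomes))
        f = m * rng.random(num_outcomes)  # e.g. a per-outcome loss bounded by M
        gap = abs(float(f @ p - f @ q))
        worst = max(worst, gap / (m * tv_distance(p, q)))
    return worst  # stays <= 1.0, up to floating-point error


if __name__ == "__main__":
    print(worst_bound_ratio())
    # Equality is attained by f = M * 1[p > q]:
    p = np.array([0.5, 0.3, 0.2])
    q = np.array([0.2, 0.3, 0.5])
    f = (p > q).astype(float)
    print(f @ p - f @ q, tv_distance(p, q))  # both equal 0.3 up to rounding
```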

Figures (14)

  • Figure 1: Overview of SuperStAR. At deployment, assume that a distribution shift causes a drop in accuracy. This is detected through changes in the Wasserstein distance between a validation set and the corrupted set. SuperStAR computes a composition of transforms $\mathcal{I}_k$ to adapt to the shift and recover accuracy. This composition of SuperStAR with the classifier helps it detect and adapt, boosting robustness of classification.
  • Figure 2: Example transformations applied to an image with contrast shift level 5. CLAHE($x$,$y$) denotes histogram equalization with strength determined by $x$ and $y$ (details in the appendix). The policy applies a non-trivial composition of transformations that would be difficult to find through manual manipulation. The policy chooses few redundant actions and improves the accuracy of an AugMix-trained ResNet-50 on a random batch of 1000 images.
  • Figure 3: Operation of SuperStAR. Starting from $\mathbf{V}_c$, the algorithm selects a sequence of transforms $\mathcal{T}$ that move $\mathbf{V}_c$ closer to the original distribution $\mathbf{V}$. During the sequence, the orthonormal projections $\varphi(\mathbf{V})$ and $\varphi(\mathbf{V}_c)$ are used to compute the Wasserstein distance $W_p(\varphi(\mathbf{V}_c), \varphi(\mathbf{V}))$. See the section on orthonormal projection for details.
  • Figure 4: Empirical Wasserstein distance between MNIST and MNIST with varied levels of additive Gaussian noise, measured over a range of sample sizes. Curves are taken over 5 trials. MNIST samples are downsampled to 24$\times$24, flattened, and projected to 50 dimensions. Orthonormal projection better preserves distributional information than Gaussian random projection and sparse random projection.
  • Figure 5: Empirical Wasserstein distance between MNIST and MNIST with varied levels of additive impulse noise, measured over a range of sample sizes. Curves are taken over 5 trials. MNIST samples are downsampled to 24$\times$24, flattened, and projected to 50 dimensions. Orthonormal projection better preserves distributional information than Gaussian random projection and sparse random projection.
  • ...and 9 more figures
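Figure 2 uses CLAHE as one of the transforms. The sketch below is a simplified, global (non-tiled) contrast-limited histogram equalization in NumPy, assuming a uint8 grayscale image; real CLAHE (for example OpenCV's `createCLAHE`) additionally equalizes local tiles and interpolates between them. The function name and the meaning of `clip_limit` (a multiple of the mean bin count) are assumptions, not the paper's CLAHE($x$,$y$) parameterization.

```python
import numpy as np


def contrast_limited_equalize(img: np.ndarray, clip_limit: float = 2.0) -> np.ndarray:
    """Global contrast-limited histogram equalization of a uint8 grayscale image.

    Simplified stand-in for CLAHE: no tiling or bilinear interpolation between
    tiles. clip_limit is a multiple of the mean bin count (assumed convention).
    """
    assert img.dtype == np.uint8, "this sketch expects 8-bit grayscale input"
    hist = np.bincount(img.ravel(), minlength=256).astype(np.float64)
    limit = clip_limit * hist.mean()
    excess = np.clip(hist - limit, 0.0, None).sum()
    hist = np.minimum(hist, limit) + excess / 256.0  # redistribute clipped mass evenly
    cdf = np.cumsum(hist)
    cdf = (cdf - cdf[0]) / (cdf[-1] - cdf[0])  # rescale so the mapping spans [0, 1]
    lut = np.round(255.0 * cdf).astype(np.uint8)  # monotone intensity lookup table
    return lut[img]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    low_contrast = (100 + 40 * rng.random((64, 64))).astype(np.uint8)  # values in [100, 139]
    out = contrast_limited_equalize(low_contrast)
    print(low_contrast.std(), out.std())
```

Clipping the histogram caps how steeply any intensity band can be stretched, which is what keeps this family of transforms from over-amplifying noise; smaller `clip_limit` values give a gentler equalization.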

Theorems & Definitions (8)

  • definition 1: Semantic Preserving Transform
  • theorem 1
  • proof
  • definition 2: MDP
  • definition 3: Operable Set
  • lemma 1
  • corollary 1
  • proof