
Deep Neural Networks Tend To Extrapolate Predictably

Katie Kang, Amrith Setlur, Claire Tomlin, Sergey Levine

TL;DR

This work revisits how neural networks extrapolate under out-of-distribution conditions, revealing a robust tendency to revert toward an optimal constant solution (OCS) that minimizes training loss without depending on inputs. It combines extensive experiments across vision and language, diverse losses, and architectures with empirical and theoretical analyses (including deep homogeneous ReLU networks) to explain why OOD representations shrink and outputs become input-agnostic. The authors demonstrate that this reversion can be harnessed for risk-sensitive decision-making, notably in selective classification, by aligning the OCS with desired cautious behavior. While the phenomenon is pervasive, the paper also discusses limitations (e.g., adversarial cases) and outlines directions for further research and practical safeguards.
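Here is a minimal sketch of how the OCS can encode cautious behavior. Consider a reward-prediction model for selective classification. For each action (one per class, plus abstain), it regresses the reward that action would earn, and a greedy policy picks the action with the highest predicted reward. Under MSE, the OCS for each action is that action's average training reward. The reward values below (+1 correct, -4 incorrect, 0 abstain) and the function names are hypothetical choices for illustration and are not taken from the paper. With a balanced 10-class label distribution, these values make every class action's OCS negative, so a policy acting on the OCS abstains.

```python
import numpy as np

# Hypothetical reward scheme (an assumption for illustration):
# +1 for predicting the true class, -4 for a wrong class, 0 for abstaining.
R_CORRECT, R_WRONG, R_ABSTAIN = 1.0, -4.0, 0.0


def ocs_rewards(label_freqs):
    """Per-action OCS under MSE: the average training reward of each action.

    Entries 0..K-1 are the class actions; the final entry is abstain.
    """
    label_freqs = np.asarray(label_freqs, dtype=float)
    class_rewards = label_freqs * R_CORRECT + (1.0 - label_freqs) * R_WRONG
    return np.append(class_rewards, R_ABSTAIN)


def greedy_policy(predicted_rewards):
    """Choose the action with the highest predicted reward (index K = abstain)."""
    return int(np.argmax(predicted_rewards))


label_freqs = np.full(10, 0.1)  # balanced 10-class training labels
ocs = ocs_rewards(label_freqs)   # each class action: 0.1 * 1 + 0.9 * (-4) = -3.5
print(ocs)
print(greedy_policy(ocs))        # 10, i.e. the abstain action
```

If predictions drift toward the OCS on OOD inputs, the same greedy policy drifts toward abstaining. That is the behavior this reward design was chosen to produce.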

Abstract

Conventional wisdom suggests that neural network predictions tend to be unpredictable and overconfident when faced with out-of-distribution (OOD) inputs. Our work reassesses this assumption for neural networks with high-dimensional inputs. Rather than extrapolating in arbitrary ways, we observe that neural network predictions often tend towards a constant value as input data becomes increasingly OOD. Moreover, we find that this value often closely approximates the optimal constant solution (OCS), i.e., the prediction that minimizes the average loss over the training data without observing the input. We present results showing this phenomenon across 8 datasets with different distributional shifts (including CIFAR10-C and ImageNet-R, S), different loss functions (cross entropy, MSE, and Gaussian NLL), and different architectures (CNNs and transformers). Furthermore, we present an explanation for this behavior, which we first validate empirically and then study theoretically in a simplified setting involving deep homogeneous networks with ReLU activations. Finally, we show how one can leverage our insights in practice to enable risk-sensitive decision-making in the presence of OOD inputs.
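The OCS has a closed form for each of the three losses named above:

- Cross entropy: the empirical label distribution.
- MSE: the mean target.
- Gaussian NLL, for a head that predicts a mean and a variance: the target mean together with the 1/N target variance.

The sketch below computes each one. It assumes NumPy arrays of integer labels or real-valued targets, and the function names are illustrative.

```python
import numpy as np


def ocs_cross_entropy(labels, num_classes):
    """Cross entropy: the OCS is the empirical label distribution."""
    counts = np.bincount(labels, minlength=num_classes)
    return counts / counts.sum()


def ocs_mse(targets):
    """Mean squared error: the OCS is the mean training target."""
    return np.mean(targets, axis=0)


def ocs_gaussian_nll(targets):
    """Gaussian NLL with predicted (mean, variance): the OCS is the target
    mean and the 1/N (maximum-likelihood) target variance."""
    mu = np.mean(targets, axis=0)
    var = np.mean((targets - mu) ** 2, axis=0)
    return mu, var


def mean_cross_entropy(probs, labels):
    """Average cross-entropy loss of one constant prediction over the labels."""
    return float(-np.mean(np.log(probs[labels])))


labels = np.array([0, 0, 1, 2])
targets = np.array([1.0, 2.0, 3.0, 6.0])
print(ocs_cross_entropy(labels, 3))   # empirical label frequencies
print(ocs_mse(targets))               # 3.0
print(ocs_gaussian_nll(targets))      # mean 3.0, variance 3.5
```

No input appears anywhere in these functions: the OCS depends only on the training targets.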

Paper Structure (36 sections, 9 theorems, 48 equations, 16 figures, 1 table)

This paper contains 36 sections, 9 theorems, 48 equations, 16 figures, 1 table.

Key Result

Proposition 4.1

When $f(\hat{W}; x)$ fits $\mathcal{D}$, i.e., $y_i f(\hat{W}; x_i) \geq \gamma$ for all $i \in [N]$, then with probability at least $1-\delta$ over $\mathcal{D}$, the layer-$j$ representations $f_j(\hat{W}; x)$ satisfy $\mathbb{E}_{P_\mathrm{train}}[\|f_j(\hat{W}; x)\|_2] \geq (1/C_0)\,\big(\gamma - \tilde{\mathcal{O}}(\ldots)\big)$ …
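Shrinking representation norms matter for a simple reason. A ReLU network without biases is positively homogeneous in its input, so scaling a feature vector down scales the output down by the same factor. Once biases are added, what remains as the features shrink is an input-independent constant. The sketch below demonstrates this with a small random MLP; the sizes, seed, and weight scale are arbitrary assumptions. It scales the input directly, as a stand-in for the smaller-norm representations that OOD inputs tend to produce. It illustrates the mechanism only and does not reproduce the paper's analysis.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical 3-layer ReLU MLP with random weights (illustration only).
sizes = [8, 16, 16, 4]
Ws = [rng.normal(size=(m, n)) / np.sqrt(n) for n, m in zip(sizes[:-1], sizes[1:])]
bs = [0.1 * rng.normal(size=m) for m in sizes[1:]]


def mlp(x, use_bias=True):
    """Forward pass; ReLU on hidden layers, linear output layer."""
    h = x
    for i, (W, b) in enumerate(zip(Ws, bs)):
        h = W @ h + (b if use_bias else 0.0)
        if i < len(Ws) - 1:
            h = np.maximum(h, 0.0)
    return h


x = rng.normal(size=sizes[0])
# Without biases the network is positively 1-homogeneous in its input.
assert np.allclose(mlp(3.0 * x, use_bias=False), 3.0 * mlp(x, use_bias=False))

# With biases, shrinking the input pushes the output toward mlp(0): an
# input-independent constant determined by the biases (propagated through the weights).
const = mlp(np.zeros(sizes[0]))
for c in [1.0, 0.1, 0.01, 0.0]:
    print(c, np.linalg.norm(mlp(c * x) - const))
```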

Figures (16)

  • Figure 1: A summary of our observations. On in-distribution samples (top), neural network outputs tend to vary significantly based on input labels. In contrast, on OOD samples (bottom), we observe that model predictions tend to not only be more similar to one another, but also gravitate towards the optimal constant solution (OCS). We also observe that OOD inputs tend to map to representations with smaller magnitudes, leading to predictions largely dominated by the (constant) network biases, which may shed light on why neural networks have this tendency.
  • Figure 2: Neural network predictions from training with cross entropy and Gaussian NLL on MNIST (top 3 rows) and CIFAR10 (bottom 3 rows). The models were trained without rotation (MNIST) or noise (CIFAR10) and evaluated on increasingly OOD inputs: images of the digit 6 for MNIST and of automobiles for CIFAR10. The blue plots show the average model prediction over the evaluation dataset; the orange plots show the OCS associated with each model. As the distribution shift increases (left to right), the network predictions tend towards the OCS (rightmost column).
  • Figure 3: Evaluating the distance between network predictions and the OCS as the input distribution becomes more OOD. Each point represents a different evaluation dataset, with the red star representing the (holdout) training distribution, and circles representing OOD datasets. The vertical line associated with each point represents the standard deviation over 5 training runs. As the OOD score of the evaluation dataset increases, there is a clear trend of the neural network predictions approaching the OCS.
  • Figure 4: Analysis of the interaction between representations and weights as distribution shift increases. Plots in the first column show the norm of network features at different layers of the network for different levels of distribution shift; in later layers, the feature norm tends to decrease as distribution shift increases. Plots in the second column show the proportion of network features that lie within the span of the following linear layer, which also tends to decrease as distribution shift increases. Error bars represent the standard deviation over the test distribution. Plots in the third and fourth columns compare the accumulation of model constants with the OCS for a cross-entropy model and an MSE model; the two closely mirror each other.
  • Figure 5: Selective classification via reward prediction on CIFAR10. We evaluate on holdout datasets consisting of automobiles (class 1) with increasing levels of noise. The x-axis shows the agent's actions, where classes are indexed by number and abstain is denoted "A". We plot the average reward predicted by the model for each class (top) and the distribution of actions selected by the policy (bottom). The rightmost plots show the OCS (top) and the actions selected by an OCS policy (bottom). As the distribution shift increases, the model predictions approach the OCS and the policy automatically selects the abstain action more frequently.
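The quantities plotted in Figures 3 and 4 can be approximated with two simple helpers. The sketch makes two assumptions that the captions do not state. First, it measures the distance between the average softmax prediction and the OCS with KL divergence. Second, it reads "proportion of network features within the span of the following linear layer" as the fraction of a feature vector's norm that lies in the row space of that layer's weight matrix. Function names are hypothetical.

```python
import numpy as np


def kl_to_ocs(mean_pred, ocs, eps=1e-12):
    """KL(OCS || average prediction) for categorical distributions (assumed metric)."""
    p = np.clip(np.asarray(ocs, dtype=float), eps, 1.0)
    q = np.clip(np.asarray(mean_pred, dtype=float), eps, 1.0)
    return float(np.sum(p * (np.log(p) - np.log(q))))


def span_fraction(h, W, tol=1e-10):
    """Fraction of ||h|| lying in the row space of W, i.e. in the input
    directions the next linear layer responds to (assumed interpretation)."""
    _, s, Vt = np.linalg.svd(W, full_matrices=False)
    basis = Vt[s > tol * s.max()]        # orthonormal rows spanning row(W)
    proj = basis.T @ (basis @ h)         # orthogonal projection of h onto row(W)
    return float(np.linalg.norm(proj) / np.linalg.norm(h))


# Average the softmax outputs over an evaluation set of shape (N, K).
probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]])
ocs = np.full(3, 1.0 / 3.0)              # e.g. balanced training labels
print(kl_to_ocs(probs.mean(axis=0), ocs))

# This W only "sees" the first two coordinates of a 3-d feature vector.
W = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(span_fraction(np.array([3.0, 0.0, 4.0]), W))   # approximately 0.6
```

A feature vector that is orthogonal to the row space of the next layer contributes nothing to that layer's output, so that layer's output reduces to its bias.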

Theorems & Definitions (9)

  • Proposition 4.1: $P_\mathrm{train}$ observes high-norm features
  • Theorem 4.1: Feature norms can drop easily on $P_\mathrm{OOD}$
  • Proposition 4.2: Analyzing network bias
  • Lemma D.1: Gradient flow is implicitly biased towards minimum $\|\cdot\|_2$
  • Lemma D.2: Adaptation of Theorem 1.1 from Bartlett et al. (2017)
  • Proposition D.1: $P_\mathrm{train}$ observes high-norm features
  • Theorem D.3: Feature norms can drop easily on $P_\mathrm{OOD}$
  • Lemma D.4: Gradient flow (GF) on deep and wide nets learns low-rank $W_1, \ldots, W_L$
  • Proposition D.2: Analyzing network bias
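The implicit-bias statement in Lemma D.1 has a well-known linear analogue that is easy to check numerically. Gradient descent on an underdetermined least-squares problem, started from zero, converges to the minimum-$\|\cdot\|_2$ interpolating solution, because every gradient lies in the row space of the data matrix. The sketch below shows only this linear case; the problem sizes, step size, and iteration count are arbitrary assumptions. It is not a reproduction of the lemma or its setting.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 20                      # fewer equations than unknowns: infinitely many exact fits
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

w = np.zeros(d)                   # zero init keeps w in the row space of X
lr = 0.01                         # small step: a discretization of gradient flow
for _ in range(10000):
    grad = X.T @ (X @ w - y)      # gradient of 0.5 * ||Xw - y||^2, always in row(X)
    w -= lr * grad

w_min_norm = np.linalg.pinv(X) @ y        # minimum-L2-norm interpolating solution
print(np.linalg.norm(X @ w - y))          # close to 0: GD interpolates the data
print(np.linalg.norm(w - w_min_norm))     # close to 0: and finds the min-norm fit
```

Starting from a nonzero initialization would add that initialization's component outside the row space of X to the final solution, so the zero start is what makes the limit the minimum-norm fit.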