Deep k-Nearest Neighbors: Towards Confident, Interpretable and Robust Deep Learning

Nicolas Papernot, Patrick McDaniel

TL;DR

This paper introduces Deep k-Nearest Neighbors (DkNN), a hybrid approach that leverages k-NN searches over layer-wise DNN representations and applies inductive conformal prediction to generate calibrated confidence (credibility) and human-interpretable explanations via training exemplars. By enforcing cross-layer conformity to the training data, DkNN yields well-calibrated confidence, transparent explanations, and enhanced robustness to inputs outside the training manifold, including adversarial examples. The authors validate DkNN on MNIST, SVHN, and GTSRB, showing improved credibility calibration for out-of-distribution inputs and meaningful interpretability through exemplars, as well as improved detection of adversarial inputs and insight into mispredictions. The work demonstrates that integrating simple, layer-wise validation of internal representations can significantly improve trust and security in deep learning systems, and points to further research in adaptive attacks and broader application domains.
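
The inductive conformal prediction step can be sketched as follows. This sketch assumes nonconformity scores have already been computed for a held-out calibration set (each point scored against its true label) and for the test input under every candidate label. The function name `conformal_predict` and its array layout are illustrative assumptions, not the authors' code. Following the usual conformal-prediction definitions, the p-value of label j is the fraction of calibration scores at least as large as the test score for j. The prediction is the label with the highest p-value, credibility is that p-value, and confidence is one minus the second-highest p-value.

```python
import numpy as np


def conformal_predict(calib_scores, test_scores):
    """Inductive conformal prediction from precomputed nonconformity scores.

    calib_scores: shape (n_calib,), nonconformity of each calibration point
        with respect to its true label.
    test_scores: shape (n_classes,), nonconformity of the test input for
        each candidate label j.
    Returns (prediction, confidence, credibility, p_values).
    """
    calib = np.asarray(calib_scores, dtype=float)
    alpha = np.asarray(test_scores, dtype=float)
    # p_j: fraction of calibration scores at least as large as alpha_j.
    p_values = (calib[None, :] >= alpha[:, None]).mean(axis=1)
    # Labels sorted by decreasing p-value.
    order = np.argsort(p_values)[::-1]
    prediction = int(order[0])
    credibility = float(p_values[order[0]])
    # Confidence: one minus the second-highest p-value (needs >= 2 classes).
    confidence = 1.0 - float(p_values[order[1]])
    return prediction, confidence, credibility, p_values
```

For example, with calibration scores 0 through 9 and test scores [2, 8, 15] for three classes, the p-values are 0.8, 0.2 and 0.0. Label 0 is therefore predicted with credibility 0.8 and confidence 0.8. A test input whose scores exceed every calibration score receives low credibility for all labels, which is the behavior the summary attributes to out-of-distribution inputs.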

Abstract

Deep neural networks (DNNs) enable innovative applications of machine learning like image recognition, machine translation, or malware detection. However, deep learning is often criticized for its lack of robustness in adversarial settings (e.g., vulnerability to adversarial inputs) and general inability to rationalize its predictions. In this work, we exploit the structure of deep learning to enable new learning-based inference and decision strategies that achieve desirable properties such as robustness and interpretability. We take a first step in this direction and introduce the Deep k-Nearest Neighbors (DkNN). This hybrid classifier combines the k-nearest neighbors algorithm with representations of the data learned by each layer of the DNN: a test input is compared to its neighboring training points according to the distance that separates them in the representations. We show that the labels of these neighboring points afford confidence estimates for inputs outside the model's training manifold, including malicious inputs like adversarial examples, and thereby provide protection against inputs that are outside the model's understanding. This is because the nearest neighbors can be used to estimate the nonconformity of, i.e., the lack of support for, a prediction in the training data. The neighbors also constitute human-interpretable explanations of predictions. We evaluate the DkNN algorithm on several datasets, and show that the confidence estimates accurately identify inputs outside the model's training manifold, and that the explanations provided by nearest neighbors are intuitive and useful in understanding model failures.
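
The layer-wise comparison described in the abstract can be sketched as a nonconformity score. For each candidate label j, the score counts the neighbors, summed over all layers, whose training label differs from j. This sketch assumes the per-layer representations are already available as NumPy arrays. It uses exact brute-force Euclidean search, which is an illustrative choice; realistic training-set sizes would call for an approximate nearest-neighbor index. The helper names `knn_labels` and `dknn_nonconformity` are hypothetical.

```python
import numpy as np


def knn_labels(train_reps, train_labels, test_rep, k):
    """Labels of the k training points closest to test_rep (Euclidean)."""
    # Exact brute-force search over all training representations.
    dists = np.linalg.norm(train_reps - test_rep[None, :], axis=1)
    nearest = np.argsort(dists)[:k]
    return train_labels[nearest]


def dknn_nonconformity(layer_train_reps, train_labels, layer_test_reps, k, n_classes):
    """alpha[j] = number of neighbors, summed over layers, whose label != j.

    layer_train_reps: list with one (n_train, d_layer) array per layer.
    layer_test_reps: list with one (d_layer,) array per layer, same order.
    """
    train_labels = np.asarray(train_labels)
    alpha = np.zeros(n_classes, dtype=int)
    for train_reps, test_rep in zip(layer_train_reps, layer_test_reps):
        neighbors = knn_labels(np.asarray(train_reps, dtype=float), train_labels,
                               np.asarray(test_rep, dtype=float), k)
        # Every neighbor that disagrees with label j adds to j's nonconformity.
        for j in range(n_classes):
            alpha[j] += int(np.sum(neighbors != j))
    return alpha
```

Consider a toy two-layer example with k = 3 and two well-separated clusters. An input whose neighbors fall in class 0 at both layers scores [0, 6]. An input that drifts to class 1 at the second layer scores [3, 3], so no label is well supported. Scores of this kind are what a calibration-based step turns into credibility.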

Paper Structure

This paper contains 31 sections, 5 equations, 13 figures, 3 tables, 1 algorithm.

Figures (13)

  • Figure 1: Intuition behind the Deep k-Nearest Neighbors (DkNN). Consider a deep neural network (left), the representations output by each layer (middle), and the nearest neighbors found at each layer in the training data (right). Drawings of pandas and school buses indicate training points. Confidence is high when the nearest neighbors' labels are homogeneous (e.g., here for the unmodified panda image). Interpretability of the outcome of each layer is provided by the nearest neighbors. Robustness stems from detecting nonconformal predictions from the nearest-neighbor labels found for out-of-distribution inputs (e.g., an adversarial panda) across different layers. Representation spaces are high-dimensional but depicted in 2D for clarity.
  • Figure 2: Reliability diagrams of DNN softmax confidence (left) and DkNN credibility (right) on test data. Bars (left axis) indicate the mean accuracy of predictions binned by credibility; the red line (right axis) illustrates data density across bins. The softmax outputs high confidence on most of the data, while DkNN credibility spreads across the value range.
  • Figure 3: Mislabeled inputs from the MNIST (top) and SVHN (bottom) test sets. We found these points by searching for inputs that the DkNN classifies with strong credibility in a class different from the label found in the dataset.
  • Figure 4: DkNN credibility vs. softmax confidence on out-of-distribution test data. The lower credibility of DkNN predictions (solid lines) compared to the softmax confidence (dotted lines) is desirable here because the test inputs are not part of the distribution on which the model was trained: they are from another dataset or created by rotating inputs.
  • Figure 5: Debugging ResNet model biases. This illustrates how the DkNN algorithm helps to understand a bias identified by Stock and Cissé [stock2017convnets] in the ResNet model for ImageNet. The image at the bottom of each column is the test input presented to the DkNN. Each test input is cropped slightly differently to include (left) or exclude (right) the football. Images shown at the top are nearest neighbors in the predicted class according to the representation output by the last hidden layer. This comparison suggests that the "basketball" prediction may have been a consequence of the ball being in the picture. Also note how the white apparel color and general arm positions of players often match the test image of Barack Obama.
  • ...and 8 more figures
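
The reliability diagram in Figure 2 bins predictions by their score and reports, per bin, the mean accuracy (bars) and the share of the data (red line). The sketch below computes both quantities. It assumes equal-width bins on [0, 1], with 10 bins as a hypothetical default. The function name `reliability_diagram` is illustrative, and the same routine applies to softmax confidence or DkNN credibility.

```python
import numpy as np


def reliability_diagram(scores, correct, n_bins=10):
    """Per-bin mean accuracy and data density for scores in [0, 1].

    scores: confidence or credibility of each prediction.
    correct: 1 if the prediction was right, else 0.
    Returns (accuracy, density), each of shape (n_bins,).
    """
    scores = np.asarray(scores, dtype=float)
    correct = np.asarray(correct, dtype=float)
    # Equal-width bins; a score of exactly 1.0 goes into the last bin.
    idx = np.minimum((scores * n_bins).astype(int), n_bins - 1)
    counts = np.bincount(idx, minlength=n_bins)
    hits = np.bincount(idx, weights=correct, minlength=n_bins)
    with np.errstate(invalid="ignore", divide="ignore"):
        accuracy = hits / counts  # NaN where a bin is empty
    density = counts / len(scores)
    return accuracy, density
```

Empty bins yield NaN accuracy, so a plotting routine would simply skip them. A well-calibrated score produces bar heights that track the bin centers. A score concentrated near 1.0 instead puts most of the density into the last bin, as the caption describes for the softmax.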