Explaining and Improving Model Behavior with k Nearest Neighbor Representations

by Nazneen Fatema Rajani, et al.

Interpretability techniques in NLP have mainly focused on understanding individual predictions using attention visualization or gradient-based saliency maps over tokens. We propose using k nearest neighbor (kNN) representations to identify the training examples responsible for a model's predictions and to obtain a corpus-level understanding of the model's behavior. Beyond interpretability, we show that kNN representations are effective at uncovering learned spurious associations, identifying mislabeled examples, and improving the fine-tuned model's performance. We focus on Natural Language Inference (NLI) as a case study and experiment with multiple datasets. For BERT and RoBERTa, our method backs off to kNN predictions on examples where model confidence is low, without any update to the model parameters. Our results indicate that the kNN approach makes the fine-tuned model more robust to adversarial inputs.
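The backoff idea in the abstract can be illustrated with a small sketch. This is not the authors' implementation, and the abstract does not specify the details, so several assumptions are made: each example is already encoded as a fixed-size vector (for instance, a representation taken from the fine-tuned model), neighbors are found by Euclidean distance, the kNN prediction is a majority vote over k neighbors, and "low confidence" means the model's highest class probability falls below a hypothetical `threshold`. The functions `nearest_neighbors` and `knn_backoff_predict` are hypothetical names used only for illustration.

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def nearest_neighbors(query: Vector, train_reps: List[Vector], k: int) -> List[int]:
    """Return indices of the k training representations closest to `query`.

    Uses Euclidean distance (an assumption for this sketch); ties in distance
    are broken by the smaller training index.
    """
    scored = sorted((math.dist(query, rep), i) for i, rep in enumerate(train_reps))
    return [i for _, i in scored[:k]]


def knn_backoff_predict(
    query_rep: Vector,
    model_probs: Sequence[float],
    train_reps: List[Vector],
    train_labels: List[int],
    k: int = 3,
    threshold: float = 0.6,  # hypothetical confidence cutoff
) -> Tuple[int, List[int]]:
    """Predict a label, backing off to kNN when the model is not confident.

    Returns (label, neighbor_indices). The neighbors are always returned so
    they can serve as training-set evidence for the prediction, even when the
    model's own label is kept.
    """
    neighbors = nearest_neighbors(query_rep, train_reps, k)
    # Model confidence = highest class probability (assumed criterion).
    model_label = max(range(len(model_probs)), key=lambda c: model_probs[c])
    if model_probs[model_label] >= threshold:
        return model_label, neighbors
    # Low confidence: majority vote over neighbor labels. On a tie, the label
    # whose first occurrence is nearest to the query wins.
    votes = Counter(train_labels[i] for i in neighbors)
    return votes.most_common(1)[0][0], neighbors


# Toy 2-D "representations"; labels 0/1/2 stand in for entailment/neutral/contradiction.
train_reps = [[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 0.9], [0.0, 0.2]]
train_labels = [0, 0, 2, 2, 1]

# The model leans toward label 2 with low confidence (0.4 < 0.6), so kNN decides.
print(knn_backoff_predict([0.0, 0.05], [0.3, 0.3, 0.4], train_reps, train_labels))
# → (0, [0, 1, 4])
```

Returning the neighbor indices together with the label reflects the abstract's point that kNN can surface the training examples behind a prediction. In a real setting the vectors would come from the fine-tuned model rather than toy 2-D points, and the distance metric, k, and threshold would need to be chosen on held-out data.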
