
VI-OOD: A Unified Representation Learning Framework for Textual Out-of-distribution Detection

Li-Ming Zhan, Bo Liu, Xiao-Ming Wu

TL;DR

This work addresses the challenge of detecting out-of-distribution (OOD) inputs in NLP. It observes that training to maximize the conditional likelihood $p(y|x)$ biases learned representations toward in-distribution (ID) tasks. It introduces VI-OOD, a variational inference framework that instead maximizes the joint likelihood $p(x,y)$ using a latent variable $Z$ and an amortized posterior $q(z|x)$, which yields richer latent representations for discriminating OOD inputs. The method uses a Transformer-based encoder to obtain $q(z|x)$, a lightweight decoder to reconstruct a target representation $x^{\text{target}}$ from $Z$, and a discriminator $f_{ID}$ for ID classification. The reconstruction target is a linear combination of intermediate hidden states with learnable weights $\mathbf{s}$. The approach improves textual OOD detection for both encoder- and decoder-based Transformer models, and the authors release their code for reproducibility. Overall, VI-OOD offers a principled, scalable way to exploit the hierarchical representations of Transformers in safety-critical NLP applications that face distribution shift.
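The components named above can be illustrated with a small, framework-free sketch. It is not the authors' implementation. It assumes a diagonal-Gaussian posterior $q(z|x)$ with a standard-normal prior, a unit-variance Gaussian decoder likelihood (so reconstruction reduces to a scaled squared error), and softmax-normalized combination weights $\mathbf{s}$. These choices, and every function name below, are hypothetical.

```python
import math
import random
from typing import List, Sequence


def softmax(v: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


def combine_hidden_states(hidden: Sequence[Sequence[float]], s: Sequence[float]) -> List[float]:
    """x_target = sum_l w_l * h_l over layers l.
    ASSUMPTION: the learnable scores s are normalized with a softmax."""
    w = softmax(s)
    dim = len(hidden[0])
    return [sum(w_l * h_l[d] for w_l, h_l in zip(w, hidden)) for d in range(dim)]


def reparameterize(mu: Sequence[float], logvar: Sequence[float], rng: random.Random) -> List[float]:
    """Draw z ~ N(mu, diag(exp(logvar))) as z = mu + sigma * eps, eps ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)).
    ASSUMPTION: the prior p(z) is a standard normal."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def negative_elbo(recon: Sequence[float], target: Sequence[float],
                  logits: Sequence[float], label: int,
                  mu: Sequence[float], logvar: Sequence[float]) -> float:
    """-ELBO = reconstruction NLL + classification NLL + KL (constants dropped).
    ASSUMPTION: p(x_target|z) is a unit-variance Gaussian, so its NLL is
    0.5 * squared error up to an additive constant."""
    recon_nll = 0.5 * sum((r - t) ** 2 for r, t in zip(recon, target))
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(a - m) for a in logits))
    class_nll = log_norm - logits[label]  # cross-entropy, i.e. -log p(y|z)
    return recon_nll + class_nll + kl_to_standard_normal(mu, logvar)


if __name__ == "__main__":
    # Toy example: three "layers" of 2-d [CLS] states with equal scores s.
    hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    target = combine_hidden_states(hidden, [0.0, 0.0, 0.0])
    print(target)  # equal weights give the layer average, approximately [0.667, 0.667]
    mu, logvar = [0.1, -0.2], [0.0, 0.0]
    z = reparameterize(mu, logvar, random.Random(0))
    # Hypothetical stand-ins for the decoder and f_ID: identity maps on z.
    recon, logits = list(z), list(z)
    print(negative_elbo(recon, target, logits, label=0, mu=mu, logvar=logvar))
```

In a real model, `recon` and `logits` would come from the decoder and the ID classification head applied to $z$, and the scores $\mathbf{s}$ would be trained jointly with the rest of the network.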

Abstract

Out-of-distribution (OOD) detection plays a crucial role in ensuring the safety and reliability of deep neural networks in various applications. While there has been a growing focus on OOD detection in visual data, the field of textual OOD detection has received less attention. Only a few attempts have been made to directly apply general OOD detection methods to natural language processing (NLP) tasks, without adequately considering the characteristics of textual data. In this paper, we delve into textual OOD detection with Transformers. We first identify a key problem prevalent in existing OOD detection methods: the biased representation learned through the maximization of the conditional likelihood $p(y\mid x)$ can potentially result in subpar performance. We then propose a novel variational inference framework for OOD detection (VI-OOD), which maximizes the likelihood of the joint distribution $p(x, y)$ instead of $p(y\mid x)$. VI-OOD is tailored for textual OOD detection by efficiently exploiting the representations of pre-trained Transformers. Through comprehensive experiments on various text classification tasks, VI-OOD demonstrates its effectiveness and wide applicability. Our code has been released at https://github.com/liam0949/LLM-OOD.
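To make the contrast between $p(y\mid x)$ and $p(x, y)$ concrete, the standard evidence lower bound below shows one way a joint objective can be written with a latent variable $z$ and an amortized posterior $q(z\mid x)$. It is a generic variational derivation offered as a sketch, not a transcription of the paper's equations. It assumes that $x$ and $y$ are conditionally independent given $z$ and uses Jensen's inequality in the last step.

```latex
\begin{aligned}
\log p(x, y) &= \log \int p(x \mid z)\, p(y \mid z)\, p(z)\, dz \\
&= \log \mathbb{E}_{q(z \mid x)}\!\left[\frac{p(x \mid z)\, p(y \mid z)\, p(z)}{q(z \mid x)}\right] \\
&\ge \mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]
   + \mathbb{E}_{q(z \mid x)}\big[\log p(y \mid z)\big]
   - \mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)
\end{aligned}
```

The second term is the usual classification log-likelihood. The first, reconstruction term is what a $p(y\mid x)$-only objective omits; intuitively, it pushes $z$ to retain information about the input beyond what the ID labels alone require. In VI-OOD, the reconstructed quantity is the target representation $x^{\text{target}}$ rather than raw text.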

Paper Structure

This paper contains 16 sections, 5 equations, and 2 figures.

Figures (2)

  • Figure 1: OOD detection performance of the Transformer's intermediate hidden states: AUROC results for the 24 layers of RoBERTa-large (higher is better). The model is fine-tuned on SST-2 and evaluated for OOD detection on the 20NG dataset. The figure compares four commonly used OOD scoring functions: MSP (maximum softmax probability, red), Maha (Mahalanobis distance, light yellow), Cosine (blue), and Energy (green).
  • Figure 2: The architecture of the proposed framework. An encoder-based Transformer model serves as the backbone textual encoder, and the hidden states of the [CLS] token are used as the textual representations. $z$ is a latent variable conditioned on these representations. The in-distribution (ID) classification head $p(y|z)$ and the decoder $p(x^{\text{target}}|z)$ both take $z$ as input. $\mathbf{s}$ is the hidden-state combination factor, and the merged representation $x^{\text{target}}$ serves as the decoder's target.
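The four scoring functions compared in Figure 1 can be sketched as follows. By convention here, a higher score means "more in-distribution". The exact variants are assumptions, not taken from the paper:

- Maha uses class means with a shared covariance whose inverse (`precision`) is supplied by the caller.
- Cosine takes the maximum cosine similarity to stored ID training features.
- Energy uses the negative free energy at a configurable temperature.

AUROC is computed directly as the probability that a random ID example outscores a random OOD example, with ties counted as one half.

```python
import math
from typing import Sequence


def msp_score(logits: Sequence[float]) -> float:
    """MSP: maximum softmax probability over the ID classes.
    The largest term exp(m - m) equals 1, so MSP = 1 / sum_i exp(logit_i - m)."""
    m = max(logits)
    return 1.0 / sum(math.exp(a - m) for a in logits)


def energy_score(logits: Sequence[float], temperature: float = 1.0) -> float:
    """Energy: negative free energy T * logsumexp(logits / T); higher means more ID."""
    scaled = [a / temperature for a in logits]
    m = max(scaled)
    return temperature * (m + math.log(sum(math.exp(a - m) for a in scaled)))


def maha_score(feat: Sequence[float], class_means: Sequence[Sequence[float]],
               precision: Sequence[Sequence[float]]) -> float:
    """Maha: negative minimum squared Mahalanobis distance to the class means.
    ASSUMPTION: `precision` is the inverse of a covariance shared across classes,
    estimated beforehand on ID training features."""
    n = len(feat)
    best = math.inf
    for mean in class_means:
        d = [f - c for f, c in zip(feat, mean)]
        dist = sum(d[i] * precision[i][j] * d[j] for i in range(n) for j in range(n))
        best = min(best, dist)
    return -best


def cosine_score(feat: Sequence[float], train_feats: Sequence[Sequence[float]]) -> float:
    """Cosine: maximum cosine similarity to stored ID training features.
    Inputs are assumed to be nonzero vectors."""
    def cos(u: Sequence[float], v: Sequence[float]) -> float:
        dot = sum(a * b for a, b in zip(u, v))
        return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))
    return max(cos(feat, t) for t in train_feats)


def auroc(id_scores: Sequence[float], ood_scores: Sequence[float]) -> float:
    """AUROC with ID as the positive class: P(s_ID > s_OOD) + 0.5 * P(tie).
    Quadratic pairwise version, fine for small illustrations."""
    wins = 0.0
    for s_id in id_scores:
        for s_ood in ood_scores:
            if s_id > s_ood:
                wins += 1.0
            elif s_id == s_ood:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


if __name__ == "__main__":
    id_logits = [[4.0, 0.5], [3.5, 0.0]]    # confident predictions
    ood_logits = [[1.0, 0.9], [0.2, 0.4]]   # uncertain predictions
    id_msp = [msp_score(x) for x in id_logits]
    ood_msp = [msp_score(x) for x in ood_logits]
    print(auroc(id_msp, ood_msp))  # -> 1.0
```

One way to produce a per-layer comparison like Figure 1 is to compute these scores from each layer's hidden states (or from a head on top of them) and apply `auroc` layer by layer. For large evaluation sets, a sort-based AUROC implementation would be preferable to the pairwise loop.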