
From Robustness to Improved Generalization and Calibration in Pre-trained Language Models

Josip Jukić, Jan Šnajder

TL;DR

The paper tackles the challenge of improving generalization and calibrated uncertainty in pre-trained language models (PLMs). It introduces JacHess, a two-phase regularization that minimizes Jacobian and Hessian norms of intermediate representations with respect to inputs, using Hutchinson estimators and a dimension-subsampling scheme in the embedding space to handle discrete tokens. Evaluated on the GLUE benchmark with decoder-based models and a BERT baseline, JacHess outperforms unregularized fine-tuning and prior Jacobian/Hessian regularizers in both generalization and uncertainty calibration, with larger gains observed for bigger models. This work provides a practical approach to enhancing robustness, generalization, and reliability in PLMs by promoting representation smoothness across layers.
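To make the Hutchinson idea concrete, here is a minimal sketch, not the authors' implementation, of estimating the squared Frobenius norm of a Jacobian, ||J||_F^2, as the average of ||J v||^2 over random Rademacher probes v. Several choices here are assumptions. Central finite differences stand in for an autodiff Jacobian-vector product. The optional `num_dims` argument is a hypothetical dimension-subsampling scheme: each probe is nonzero only on a random subset of m input coordinates, and the result is rescaled by d/m so the estimate stays unbiased. The paper's actual subsampling of embedding dimensions may differ. All function names are illustrative.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def jvp_fd(f: Callable[[Vector], Vector], x: Vector, v: Vector, eps: float = 1e-4) -> Vector:
    """Approximate the Jacobian-vector product J_f(x) v with a central difference.

    Stand-in for an autodiff JVP; exact up to rounding when f is linear.
    """
    x_plus = [xi + eps * vi for xi, vi in zip(x, v)]
    x_minus = [xi - eps * vi for xi, vi in zip(x, v)]
    return [(a - b) / (2.0 * eps) for a, b in zip(f(x_plus), f(x_minus))]


def hutchinson_jacobian_sq_norm(
    f: Callable[[Vector], Vector],
    x: Vector,
    num_probes: int = 64,
    num_dims: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> float:
    """Estimate ||J_f(x)||_F^2 as the mean of ||J v||^2 over Rademacher probes v.

    If num_dims (m) is given, each probe is nonzero only on m randomly chosen
    input coordinates and ||J v||^2 is rescaled by d / m, which keeps the
    estimator unbiased (hypothetical subsampling scheme).
    """
    rng = rng or random.Random(0)
    d = len(x)
    m = d if num_dims is None else num_dims
    total = 0.0
    for _ in range(num_probes):
        v = [0.0] * d
        for i in rng.sample(range(d), m):  # subset of coordinates to probe
            v[i] = rng.choice((-1.0, 1.0))  # Rademacher entry
        jv = jvp_fd(f, x, v)
        total += (d / m) * sum(c * c for c in jv)
    return total / num_probes
```

For a linear map f(x) = A x, the Jacobian is A itself, so the estimate should approach the sum of squared entries of A. For A = [[1, 2, 0], [0, -1, 3]], that sum is 15. Cross terms average to zero under Rademacher probes, which is why a single probe is noisy but the mean is unbiased.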

Abstract

Enhancing generalization and uncertainty quantification in pre-trained language models (PLMs) is crucial for their effectiveness and reliability. Building on machine learning research that established the importance of robustness for improving generalization, we investigate the role of representation smoothness, achieved via Jacobian and Hessian regularization, in enhancing PLM performance. Although such regularization methods have proven effective in computer vision, their application in natural language processing (NLP), where PLM inputs are derived from a discrete domain, poses unique challenges. We introduce a novel two-phase regularization approach, JacHess, which minimizes the norms of the Jacobian and Hessian matrices within PLM intermediate representations relative to their inputs. Our evaluation using the GLUE benchmark demonstrates that JacHess significantly improves in-domain generalization and calibration in PLMs, outperforming unregularized fine-tuning and other similar regularization methods.
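The Hessian side of such a penalty can be estimated in a similar Hutchinson style. The following is a minimal sketch under stated assumptions, not the paper's method. The "Hessian" of a vector-valued representation is treated as the stack of per-output Hessians H_k. With independent Rademacher probes u and v, E[(u^T A v)^2] = ||A||_F^2, so averaging sum_k (u^T H_k v)^2 gives an unbiased estimate of the squared Frobenius norm. The bilinear term u^T H_k v is approximated here with a mixed central difference instead of nested autodiff, and the function names are hypothetical.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def mixed_second_difference(
    f: Callable[[Vector], Vector], x: Vector, u: Vector, v: Vector, eps: float = 1e-3
) -> Vector:
    """Approximate u^T H_k(x) v for every output k of f with a mixed central difference.

    Exact up to rounding when f is quadratic; a stand-in for nested autodiff.
    """
    def at(a: float, b: float) -> Vector:
        return f([xi + eps * (a * ui + b * vi) for xi, ui, vi in zip(x, u, v)])

    pp, pm, mp, mm = at(1, 1), at(1, -1), at(-1, 1), at(-1, -1)
    return [(p1 - p2 - p3 + p4) / (4.0 * eps * eps) for p1, p2, p3, p4 in zip(pp, pm, mp, mm)]


def hutchinson_hessian_sq_norm(
    f: Callable[[Vector], Vector],
    x: Vector,
    num_probes: int = 64,
    rng: Optional[random.Random] = None,
) -> float:
    """Estimate sum_k ||H_k(x)||_F^2 as the mean of sum_k (u^T H_k v)^2.

    u and v are independent Rademacher probes, so each term is unbiased
    for the squared Frobenius norm of the corresponding Hessian.
    """
    rng = rng or random.Random(0)
    d = len(x)
    total = 0.0
    for _ in range(num_probes):
        u = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        v = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        total += sum(c * c for c in mixed_second_difference(f, x, u, v))
    return total / num_probes
```

As a check, f(x) = [x0 * x1, x2^2] has per-output Hessians with squared Frobenius norms 2 and 4. The estimate should therefore approach 6 as the number of probes grows.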

Paper Structure (24 sections, 13 equations, 2 figures, 5 tables)

Figures (2)

  • Figure 1: Average generalization score under embedding perturbation. We perturb instances in the embedding space with a factor $\delta$ that controls the degree of perturbation. Scores are averaged across datasets. To avoid clutter, we plot results only for the base model without regularization, the Cross-Hölder method (which generally outperforms the Jacobian approach in this respect), and our method, JacHess. Due to space constraints, we show results for BERT, OPT-125, and LLaMA-2-7b.
  • Figure 2: Calibration plots for binary classification datasets. We aggregate results across all five seeds and use eight bins for the mean predicted probability. We show calibration plots for LLaMA-2-7b.
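The Figure 2 procedure (pool predictions from all seeds, then bin by predicted probability) can be sketched as follows. This is an illustration, not the authors' plotting code. It assumes eight equal-width bins on [0, 1] and the positive-class probability of a binary classifier. The function names are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def pool_seeds(
    runs: Sequence[Tuple[Sequence[float], Sequence[int]]]
) -> Tuple[List[float], List[int]]:
    """Concatenate (probs, labels) pairs from several seeds into one pool."""
    probs: List[float] = []
    labels: List[int] = []
    for p, y in runs:
        probs.extend(p)
        labels.extend(y)
    return probs, labels


def reliability_bins(
    probs: Sequence[float], labels: Sequence[int], n_bins: int = 8
) -> List[Optional[Tuple[float, float, int]]]:
    """Per equal-width bin on [0, 1]: (mean predicted prob, fraction positive, count).

    Empty bins are returned as None.
    """
    sums = [0.0] * n_bins
    positives = [0] * n_bins
    counts = [0] * n_bins
    for p, y in zip(probs, labels):
        b = min(int(p * n_bins), n_bins - 1)  # p == 1.0 falls into the last bin
        sums[b] += p
        positives[b] += y
        counts[b] += 1
    return [
        (sums[b] / counts[b], positives[b] / counts[b], counts[b]) if counts[b] else None
        for b in range(n_bins)
    ]
```

Each non-empty bin contributes one point (mean predicted probability, fraction of positives) to a calibration plot. A perfectly calibrated model places every point on the diagonal.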

Theorems & Definitions (2)

  • Definition 1: $L$-Lipschitz continuity
  • Definition 2: $L$-Lipschitz gradient continuity
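The outline lists the two definitions only by name. Their standard textbook forms are given below for reference. Euclidean norms are assumed, and the paper's exact formulation may differ. The last line states a standard fact that connects the definitions to Jacobian and Hessian regularization. For smooth f, the smallest valid L equals the supremum of the spectral norm of the corresponding derivative, and the spectral norm is bounded by the Frobenius norm that Hutchinson-type estimators target.

```latex
% Definition 1: f : R^n -> R^m is L-Lipschitz continuous
\|f(x) - f(y)\|_2 \le L \, \|x - y\|_2 \quad \text{for all } x, y \in \mathbb{R}^n

% Definition 2: scalar f : R^n -> R has an L-Lipschitz gradient
\|\nabla f(x) - \nabla f(y)\|_2 \le L \, \|x - y\|_2 \quad \text{for all } x, y \in \mathbb{R}^n

% Smallest constants (f in C^1, resp. C^2) and their Frobenius-norm upper bounds
L_f = \sup_{x} \|J_f(x)\|_2 \le \sup_{x} \|J_f(x)\|_F, \qquad
L_{\nabla f} = \sup_{x} \|\nabla^2 f(x)\|_2 \le \sup_{x} \|\nabla^2 f(x)\|_F
```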