From Robustness to Improved Generalization and Calibration in Pre-trained Language Models
Josip Jukić, Jan Šnajder
TL;DR
The paper addresses generalization and uncertainty calibration in pre-trained language models (PLMs). It introduces JacHess, a two-phase regularization method that minimizes the norms of the Jacobian and Hessian of intermediate representations with respect to their inputs, using Hutchinson estimators and a dimension-subsampling scheme in the embedding space to handle discrete tokens. Evaluated on the GLUE benchmark with decoder-based models and a BERT baseline, JacHess outperforms unregularized fine-tuning and prior Jacobian/Hessian regularizers in both generalization and uncertainty calibration, with larger gains observed for bigger models. The work offers a practical way to improve robustness, generalization, and reliability in PLMs by promoting representation smoothness across layers.
Abstract
Enhancing generalization and uncertainty quantification in pre-trained language models (PLMs) is crucial for their effectiveness and reliability. Building on machine learning research that established the importance of robustness for improving generalization, we investigate the role of representation smoothness, achieved via Jacobian and Hessian regularization, in enhancing PLM performance. Although such regularization methods have proven effective in computer vision, their application in natural language processing (NLP), where PLM inputs are derived from a discrete domain, poses unique challenges. We introduce a novel two-phase regularization approach, JacHess, which minimizes the norms of the Jacobian and Hessian matrices within PLM intermediate representations relative to their inputs. Our evaluation using the GLUE benchmark demonstrates that JacHess significantly improves in-domain generalization and calibration in PLMs, outperforming unregularized fine-tuning and other similar regularization methods.
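The summary above mentions Hutchinson estimators for the Jacobian norm. As a rough illustration of that idea only, and not the authors' implementation, the sketch below estimates the squared Frobenius norm of a Jacobian through the identity ||J||_F^2 = tr(J^T J) = E_v[||J v||^2], where v is a random probe with independent ±1 entries. The function names (`jvp_fd`, `hutchinson_jacobian_sq_norm`) are hypothetical. The central finite difference is an assumption made to keep the example dependency-free; a real training setup would more likely compute the Jacobian-vector product with automatic differentiation.

```python
import random


def jvp_fd(f, x, v, eps=1e-5):
    """Approximate the Jacobian-vector product J(x) v by a central finite difference.

    Assumption for this sketch: f maps a list of floats to a list of floats.
    """
    x_plus = [xi + eps * vi for xi, vi in zip(x, v)]
    x_minus = [xi - eps * vi for xi, vi in zip(x, v)]
    return [(a - b) / (2.0 * eps) for a, b in zip(f(x_plus), f(x_minus))]


def hutchinson_jacobian_sq_norm(f, x, num_probes=64, seed=0):
    """Hutchinson-style estimate of ||J(x)||_F^2 = E_v[||J(x) v||^2].

    Each probe v has independent Rademacher (+1 / -1) entries, so E[v v^T] = I
    and the average of ||J v||^2 over probes is an unbiased estimate of tr(J^T J).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_probes):
        v = [rng.choice((-1.0, 1.0)) for _ in x]
        jv = jvp_fd(f, x, v)
        total += sum(c * c for c in jv)
    return total / num_probes


def toy_layer(x):
    """Hypothetical stand-in for an intermediate representation: a fixed linear map."""
    w = [[1.0, 2.0], [3.0, 4.0], [0.0, -1.0]]
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


if __name__ == "__main__":
    # The exact squared Frobenius norm of the toy Jacobian is 1 + 4 + 9 + 16 + 0 + 1 = 31.
    estimate = hutchinson_jacobian_sq_norm(toy_layer, [0.5, -0.25], num_probes=4000)
    print(estimate)
```

In a regularizer, an estimate like this would be added to the task loss with a weighting coefficient; the estimator avoids materializing the full Jacobian, which is the usual motivation for using random probes.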
