Gaussian Stochastic Weight Averaging for Bayesian Low-Rank Adaptation of Large Language Models

Emre Onal, Klemens Flöge, Emma Caldwell, Arsen Sheverdin, Vincent Fortuin

TL;DR

This paper tackles overconfidence and miscalibration in LLMs fine-tuned on limited data. It applies Stochastic Weight Averaging-Gaussian (SWAG) to LoRA, enabling approximate Bayesian inference by sampling from a Gaussian posterior that is centered on the SWA mean and uses a diagonal plus low-rank covariance to capture uncertainty. The MultiSWAG variant, an ensemble of SWAG models whose posterior estimation is restricted to the LoRA parameters, achieves accuracy and calibration competitive with heavier Bayesian methods such as Laplace-LoRA, while also offering improved robustness to distribution shift. The method provides a lightweight, scalable pathway for uncertainty-aware fine-tuning of large language models in resource-constrained settings.
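
For reference, the posterior described above can be written in the standard SWAG form of Maddox et al. (2019). This is a sketch that assumes the paper applies that formulation unchanged to the flattened LoRA parameters. The symbols $T$ (number of collected iterates), $K$ (rank of the low-rank term), $\widehat{D}$ (matrix whose columns are deviations of recent iterates from the mean), and $S$ (number of posterior samples) come from that formulation, not from this page.

```latex
\theta_{\mathrm{SWA}} = \frac{1}{T}\sum_{t=1}^{T}\theta_t,
\qquad
\Sigma_{\mathrm{diag}} = \operatorname{diag}\!\Big(\frac{1}{T}\sum_{t=1}^{T}\theta_t^{2} - \theta_{\mathrm{SWA}}^{2}\Big)

\tilde{\theta} = \theta_{\mathrm{SWA}}
  + \frac{1}{\sqrt{2}}\,\Sigma_{\mathrm{diag}}^{1/2}\, z_1
  + \frac{1}{\sqrt{2(K-1)}}\,\widehat{D}\, z_2,
\qquad z_1 \sim \mathcal{N}(0, I_d),\quad z_2 \sim \mathcal{N}(0, I_K)

p(y \mid x) \approx \frac{1}{S}\sum_{s=1}^{S} p\big(y \mid x, \tilde{\theta}_s\big)
```

Under this form, each sample $\tilde{\theta}$ has covariance $\tfrac{1}{2}\Sigma_{\mathrm{diag}} + \tfrac{1}{2(K-1)}\widehat{D}\widehat{D}^{\top}$, which is the "diagonal plus low-rank" structure mentioned above.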

Abstract

Fine-tuned Large Language Models (LLMs) often suffer from overconfidence and poor calibration, particularly when fine-tuned on small datasets. To address these challenges, we propose a simple combination of Low-Rank Adaptation (LoRA) with Gaussian Stochastic Weight Averaging (SWAG), facilitating approximate Bayesian inference in LLMs. Through extensive testing across several Natural Language Processing (NLP) benchmarks, we demonstrate that our straightforward and computationally efficient approach improves model generalization and calibration competitively with comparable, more sophisticated methods for Bayesian inference in LLMs. We further show that our method exhibits greater robustness against distribution shift, as reflected in its improved performance on out-of-distribution tasks.

Paper Structure

This paper contains 18 sections, 5 equations, 1 figure, and 2 tables.

Figures (1)

  • Figure 1: Outline of the SWAG-LoRA training and inference process. The left panel shows the LLM architecture with LoRA fine-tuning (see the LoRA subsection). The middle and upper right panels depict the SWAG training process, in which weight samples are collected across SGD iterations to compute the mean and an approximate covariance of the posterior over network weights (see the SWAG subsection). The lower right panel shows how the ensemble of weights used for inference is formed by sampling from the learned SWAG posterior.
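
The pipeline in Figure 1 can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names (`collect_swag_moments`, `sample_swag`, `predict_bma`), the toy linear classifier standing in for a LoRA-adapted LLM, and the random-walk "trajectory" are all assumptions made for illustration. As a simplification, deviations are taken from the final mean rather than from the running mean at collection time.

```python
import numpy as np


def collect_swag_moments(snapshots, rank):
    """Compute SWAG statistics from flattened weight snapshots of shape (T, d)."""
    snapshots = np.asarray(snapshots, dtype=np.float64)
    mean = snapshots.mean(axis=0)                     # SWA mean
    second_moment = (snapshots ** 2).mean(axis=0)
    # Diagonal variance estimate, clipped to stay strictly positive.
    diag_var = np.clip(second_moment - mean ** 2, 1e-30, None)
    # Low-rank factor: deviations of the last `rank` snapshots from the mean
    # (simplification: final mean instead of the running mean).
    deviations = (snapshots[-rank:] - mean).T         # shape (d, rank)
    return mean, diag_var, deviations


def sample_swag(mean, diag_var, deviations, rng, scale=1.0):
    """Draw one weight vector from the diagonal-plus-low-rank Gaussian."""
    d, k = deviations.shape                           # requires k >= 2
    z1 = rng.standard_normal(d)
    z2 = rng.standard_normal(k)
    diag_term = np.sqrt(diag_var) * z1 / np.sqrt(2.0)
    low_rank_term = deviations @ z2 / np.sqrt(2.0 * (k - 1))
    return mean + np.sqrt(scale) * (diag_term + low_rank_term)


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def predict_bma(features, mean, diag_var, deviations, n_samples, rng):
    """Average class probabilities over sampled weights (toy linear classifier)."""
    n_features = features.shape[1]
    probs = []
    for _ in range(n_samples):
        weights = sample_swag(mean, diag_var, deviations, rng)
        probs.append(softmax(features @ weights.reshape(n_features, -1)))
    return np.mean(probs, axis=0)


rng = np.random.default_rng(0)
n_features, n_classes = 4, 3
d = n_features * n_classes
# Stand-in for adapter weights saved along the SGD trajectory (not a real run).
snapshots = np.cumsum(0.01 * rng.standard_normal((20, d)), axis=0)
mean, diag_var, deviations = collect_swag_moments(snapshots, rank=5)
features = rng.standard_normal((8, n_features))
probs = predict_bma(features, mean, diag_var, deviations, n_samples=10, rng=rng)
```

In a LoRA setting, only the adapter parameters would be flattened into the snapshot vectors, so `d` stays small relative to the full model. A MultiSWAG-style ensemble would repeat the procedure for several independently trained runs and average the resulting `probs` arrays.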