TRAM: Bridging Trust Regions and Sharpness Aware Minimization

Tom Sherborne, Naomi Saphra, Pradeep Dasigi, Hao Peng

TL;DR

TRAM addresses the challenge of fine-tuning pre-trained models under distribution shifts by unifying sharpness-aware minimization with trust-region regularization in representation space. It constrains SAM-style parameter perturbations within a trust region defined by divergences such as $d_{\theta}$ or $d_{x}$, preserving pre-trained structure while encouraging flat parameter landscapes and smooth representations. The work develops several TRAM variants (including TRAM-Fisher) and demonstrates consistent gains over SAM and traditional trust-region baselines across vision and language tasks, including cross-dataset image classification, cross-domain language modeling, and zero-shot cross-lingual transfer. The findings suggest that jointly reducing parameter sharpness and function-space curvature yields stronger out-of-distribution generalization with minimal overhead, supporting broader applicability to domain-generalizable models.

Abstract

Sharpness-aware minimization (SAM) reports improving domain generalization by reducing the loss surface curvature in the parameter space. However, generalization during fine-tuning is often more dependent on the transferability of representations in the function space. Trust-region methods (TR) target this goal by regularizing representation curvature to reduce catastrophic forgetting of pre-trained task-agnostic information while adopting task-specific skills. We consider unifying these strategies for low curvature in both parameter space and function space to improve out-of-domain (OOD) generalization. We propose Trust Region Aware Minimization (TRAM), a SAM algorithm fine-tuning for low parameter sharpness and smooth, informative representations preserving pre-trained structure. TRAM uses a trust region bound to inform the SAM adversarial neighborhood, introducing an awareness of function curvature within optimization for flatter minima. We empirically validate TRAM in vision (cross-dataset adaptation) and text (OOD language modeling, zero-shot cross-lingual transfer) tasks where robust domain transfer and representation generality are critical. TRAM outperforms SAM- and TR-based optimization across all tasks, notably surpassing competing methods for hard transfer between anticorrelated domains. TRAM establishes a novel standard in fine-tuning for domain-generalizable models with minimal additional computation over previous sharpness-aware methods.
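
The core idea in the abstract, letting a trust-region estimate set the size of the SAM adversarial neighborhood, can be illustrated with a toy example. The sketch below is not the authors' implementation: it assumes a tiny logistic-regression model, estimates a $d_{x}$-style trust region as the mean KL divergence between predictions on clean and noise-perturbed inputs, and uses that value where plain SAM would use a fixed radius. The names `trust_region_radius`, `tram_like_step`, `sigma`, and `lr` are hypothetical and chosen only for illustration.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(theta, x):
    # Toy binary logistic model: p(y = 1 | x) for weights theta.
    return sigmoid(sum(t * xi for t, xi in zip(theta, x)))


def loss_and_grad(theta, batch):
    # Mean cross-entropy loss and its gradient over (x, y) pairs.
    loss, grad = 0.0, [0.0] * len(theta)
    for x, y in batch:
        p = min(max(predict(theta, x), 1e-12), 1.0 - 1e-12)
        loss -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
        for i, xi in enumerate(x):
            grad[i] += (p - y) * xi
    n = len(batch)
    return loss / n, [g / n for g in grad]


def bernoulli_kl(p, q):
    # KL(Bernoulli(p) || Bernoulli(q)), clipped for numerical safety.
    p = min(max(p, 1e-12), 1.0 - 1e-12)
    q = min(max(q, 1e-12), 1.0 - 1e-12)
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))


def trust_region_radius(theta, batch, rng, sigma=0.1):
    # ASSUMPTION: a d_x-style estimate, i.e. the mean divergence between
    # predictions on clean inputs and on Gaussian-perturbed inputs.
    total = 0.0
    for x, _ in batch:
        noisy = [xi + rng.gauss(0.0, sigma) for xi in x]
        total += bernoulli_kl(predict(theta, x), predict(theta, noisy))
    return total / len(batch)


def tram_like_step(theta, batch, rng, lr=0.1, sigma=0.1):
    # 1) Gradient at the current weights.
    _, g = loss_and_grad(theta, batch)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
    # 2) Plain SAM would scale the ascent step by a fixed radius rho;
    #    here the trust-region estimate d takes that role.
    d = trust_region_radius(theta, batch, rng, sigma)
    theta_adv = [t + d * gi / norm for t, gi in zip(theta, g)]
    # 3) Descend from the original weights using the gradient
    #    evaluated at the perturbed weights.
    _, g_adv = loss_and_grad(theta_adv, batch)
    return [t - lr * ga for t, ga in zip(theta, g_adv)], d


rng = random.Random(0)
batch = [([1.0, 2.0], 1), ([2.0, 1.0], 1), ([-1.0, -1.5], 0), ([-2.0, -0.5], 0)]
theta = [0.0, 0.0]
start_loss, _ = loss_and_grad(theta, batch)
for _ in range(50):
    theta, d = tram_like_step(theta, batch, rng)
end_loss, _ = loss_and_grad(theta, batch)
```

In this toy setting the radius shrinks whenever the model's predictions are insensitive to input noise, so the adversarial neighborhood adapts to function smoothness rather than staying fixed. The divergences used in the paper are described as acting in representation space, so the input-noise estimate here should be read as a stand-in under the stated assumptions.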

Paper Structure (26 sections, 11 equations, 2 figures, 14 tables, 2 algorithms)


Figures (2)

  • Figure 1: TRAM introduces an awareness of function curvature (i.e., the trust region) into sharpness-aware minimization. (Left) TRAM estimates the size of the trust region, $d$, around $f\left(x\right)$, shown in green. (Right) The loss contour in parameter space, following ASAM (Kwon et al., 2021), where blue is the typical loss, red is the maximized worst-case loss for ASAM, and green is the maximized loss within the subdomain constrained for function smoothness.
  • Figure 2: Perplexity on the S2ORC training domain (Math) and zero-shot domains. We report perplexity across three panels: domains correlated with Math, grouped as Stem domains (see the appendix on domain correlation); the Art domain; and the Philosophy (Phil.) domain. Each panel includes linear regression trends: the blue dotted trend is for prior work, and the green dashed trend includes all TRAM variants. A positive slope ($\rho>0$) represents correlated domains; a negative slope ($\rho<0$) represents anticorrelated domains. We report the Pearson correlation $\rho$ for the blue trend, noting significance at $p<0.01$.
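
The Figure 2 caption distinguishes correlated from anticorrelated domains by the sign of the Pearson correlation and of the regression slope between training-domain and zero-shot perplexity. The sketch below shows how those two quantities could be computed. The perplexity values are invented placeholders for illustration only, not results from the paper, and the helper names `pearson` and `ols_slope` are hypothetical.

```python
import math


def pearson(xs, ys):
    # Pearson correlation coefficient between two equal-length sequences.
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)


def ols_slope(xs, ys):
    # Slope of the least-squares line y = a + b * x.
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    return cov / vx


# PLACEHOLDER numbers: one entry per hypothetical training method.
math_ppl = [10.0, 10.5, 11.0, 11.5]      # training-domain perplexity
near_ppl = [20.0, 21.5, 22.0, 23.5]      # a domain that tracks Math
far_ppl = [30.0, 29.0, 28.5, 27.0]       # a domain that moves the other way

rho_near = pearson(math_ppl, near_ppl)
rho_far = pearson(math_ppl, far_ppl)
slope_near = ols_slope(math_ppl, near_ppl)
slope_far = ols_slope(math_ppl, far_ppl)
```

With these placeholder values, `rho_near` and `slope_near` are positive (a correlated domain) while `rho_far` and `slope_far` are negative (an anticorrelated domain); the correlation and the slope always share a sign because both are the same covariance divided by a positive quantity.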