GeneralizeFormer: Layer-Adaptive Model Generation across Test-Time Distribution Shifts

Sameer Ambekar, Zehao Xiao, Xiantong Zhen, Cees G. M. Snoek

TL;DR

The paper tackles test-time domain generalization where target distributions are unseen during training. It introduces GeneralizeFormer, a lightweight transformer that generates target-specific BN affine parameters and classifier weights for each batch, conditioned on source-trained weights, target features, and layer-wise gradients, while keeping convolutional weights fixed to curb computation. Trained via a meta-learning scheme that simulates distribution shifts, the approach enables on-the-fly adaptation without backpropagation on the backbone, and demonstrates strong performance across six domain-generalization benchmarks, including dynamic and multi-target scenarios. The results show improved robustness to various distribution shifts, reduced forgetting of source information, and faster inference than fine-tuning-based methods, highlighting the practicality of per-batch, layer-aware model generation for deployment under distribution shift.
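As a rough illustration of the generation step described above, the Python sketch below builds one token per channel from the source-trained parameter, a target feature statistic, and a gradient. It mixes the tokens with single-head self-attention and adds a predicted update to the source parameter. Everything specific here is an assumption and not the authors' implementation: the per-channel tokenization, the residual form, the layer sizes, and the random weights (GeneralizeFormer's weights are meta-learned). The class name `TinyParamGenerator` is hypothetical.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one row of attention scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


class TinyParamGenerator:
    """Hypothetical stand-in for GeneralizeFormer (illustration only).

    Each channel becomes a token [source weight, target feature statistic,
    gradient]. One self-attention layer mixes the tokens, and a linear head
    predicts a scalar update per channel. Weights are random here, whereas
    the paper meta-learns them.
    """

    def __init__(self, d_in=3, d_model=8, seed=0):
        rng = random.Random(seed)

        def init(rows, cols):
            return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

        self.d_model = d_model
        self.w_q = init(d_in, d_model)
        self.w_k = init(d_in, d_model)
        self.w_v = init(d_in, d_model)
        self.w_out = init(d_model, 1)  # maps each mixed token to one update

    def __call__(self, w_src, feat, grad):
        tokens = [[w, f, g] for w, f, g in zip(w_src, feat, grad)]
        q = matmul(tokens, self.w_q)
        k = matmul(tokens, self.w_k)
        v = matmul(tokens, self.w_v)
        k_t = [list(col) for col in zip(*k)]  # transpose keys
        scale = math.sqrt(self.d_model)
        scores = [[s / scale for s in row] for row in matmul(q, k_t)]
        attn = [softmax(row) for row in scores]
        mixed = matmul(attn, v)  # each token attends to every channel
        delta = [row[0] for row in matmul(mixed, self.w_out)]
        # Assumed residual form: generated = source-trained + predicted update.
        return [w + d for w, d in zip(w_src, delta)]
```

For example, `TinyParamGenerator()([1.0, 1.0, 1.0], [0.2, -0.1, 0.5], [0.01, 0.0, -0.02])` returns three generated values in one forward pass. The gradient enters only as an input feature, so no source parameter is overwritten. Setting `w_out` to zeros makes the output equal the source parameters, which is a quick sanity check for the residual form.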

Abstract

We consider the problem of test-time domain generalization, where a model is trained on several source domains and adjusted on target domains never seen during training. Unlike common methods that fine-tune the model or adjust the classifier parameters online, we propose to generate the parameters of multiple layers on the fly during inference with a lightweight meta-learned transformer, which we call GeneralizeFormer. The layer-wise parameters are generated per target batch, without fine-tuning or online adjustment. As a result, our method is more effective in dynamic scenarios with multiple target distributions and avoids forgetting valuable source distribution characteristics. Moreover, by considering layer-wise gradients, the proposed method adapts itself to various distribution shifts. To reduce computational and time costs, we fix the convolutional parameters and generate only the parameters of the Batch Normalization layers and the linear classifier. Experiments on six widely used domain generalization datasets demonstrate that the proposed method efficiently handles various distribution shifts, generalizes in dynamic scenarios, and avoids forgetting.
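The sketch below illustrates the division of labor described in the abstract under simplifying assumptions. Features coming out of the frozen convolutional layers are taken as given. Batch Normalization normalizes them with the current target-batch statistics (an assumption; the abstract does not say which statistics are used). The affine parameters and the linear classifier come from a generator. `predict_batch` and the `generate` callable are hypothetical names. Because the generator returns a new parameter dictionary instead of overwriting the source one, every batch starts from the same source-trained parameters, which is one way to see why per-batch generation avoids forgetting.

```python
import math


def batch_norm(features, gamma, beta, eps=1e-5):
    """Normalize each channel with the current batch statistics (assumed),
    then apply the generated affine parameters gamma and beta.

    features: list of samples, each a list of per-channel activations.
    """
    n = len(features)
    channels = list(zip(*features))
    means = [sum(c) / n for c in channels]
    variances = [sum((x - m) ** 2 for x in c) / n for c, m in zip(channels, means)]
    return [
        [g * (x - m) / math.sqrt(v + eps) + b
         for x, m, v, g, b in zip(sample, means, variances, gamma, beta)]
        for sample in features
    ]


def linear(features, weight, bias):
    """Classifier logits; `weight` holds one row per class."""
    return [[sum(w * x for w, x in zip(row, sample)) + b
             for row, b in zip(weight, bias)]
            for sample in features]


def predict_batch(frozen_features, source_params, generate):
    """Classify one target batch with generated BN and classifier parameters.

    `generate` (hypothetical) maps (source_params, features) to a NEW dict;
    the source dict is never modified, and the conv features stay frozen.
    """
    params = generate(source_params, frozen_features)
    normed = batch_norm(frozen_features, params["gamma"], params["beta"])
    logits = linear(normed, params["weight"], params["bias"])
    return [max(range(len(row)), key=row.__getitem__) for row in logits]
```

With `source = {"gamma": [1.0, 1.0], "beta": [0.0, 0.0], "weight": [[1.0, 1.0], [-1.0, -1.0]], "bias": [0.0, 0.0]}` and features `[[1.0, 0.0], [3.0, 2.0]]`, an identity generator predicts `[1, 0]`. A generator that negates `gamma` predicts `[0, 1]`, yet `source["gamma"]` is still `[1.0, 1.0]` afterwards.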

Paper Structure

This paper contains 11 sections, 6 equations, 6 figures, 8 tables, and 2 algorithms.

Figures (6)

  • Figure 1: Illustration of test-time generalization methods. (a) Fine-tuning methods update the model online with large batches of samples. (b) Classifier adjustment methods update the classifier in a feed-forward manner according to the test data. (c) Our method adaptively generates target-specific parameters in different layers according to the target distribution, enabling it to handle various distribution shifts without fine-tuning.
  • Figure 2: Illustration of GeneralizeFormer. We generate the model parameters of the classifiers and the Batch Normalization layers at different levels with GeneralizeFormer. The GeneralizeFormer takes the source-trained parameters, target features, and layer-wise gradients as input and outputs the target-specific parameters. By considering the layer-wise information, the method adaptively generates target parameters for different levels of layers, enabling the model to handle various distribution shifts.
  • Figure 3: Visualizations of adaptive model generation using ResNet-18 on PACS. (a) For input-level shifts, our method focuses on generating the low-level layers, depending on the samples. (b) Similarly, for feature-level shifts that consist of subpopulations, our method mainly changes the middle layers. (c) For output-level shifts, our method changes the high-level layers more because of the category shifts, while also generating the initial layers, since this setting also contains input-level shifts.
  • Figure 4: Generalization in different scenarios: (a) small batch sizes, (b) source forgetting, and (c) multiple target distributions. Our method performs well with small batch sizes and in complex scenarios with multiple distributions, and it also avoids source forgetting.
  • Figure 5: Visualization of generated weights on PACS. Each row visualizes a 28x28 filter from the Batch Normalization layer for a sample image from the photo domain: (a) weights generated by GeneralizeFormer and (b) real weights.
  • ...and 1 more figure
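Figure 3 compares how much the generated parameters differ from the source parameters at different depths. One simple way to quantify this, assumed here rather than taken from the paper, is the per-layer relative change ||generated - source|| / ||source||. The helper names and the layer names in the usage below are hypothetical.

```python
import math


def relative_change(source, generated):
    """Relative L2 change ||generated - source|| / ||source|| for one layer."""
    diff = math.sqrt(sum((g - s) ** 2 for s, g in zip(source, generated)))
    norm = math.sqrt(sum(s * s for s in source))
    return diff / norm if norm > 0 else float("inf")


def most_changed_layer(source_layers, generated_layers):
    """Return the layer with the largest relative change and all changes.

    Both arguments map a layer name to a flat list of parameters
    (e.g. the BN affine weights of that layer).
    """
    changes = {name: relative_change(source_layers[name], generated_layers[name])
               for name in source_layers}
    return max(changes, key=changes.get), changes
```

For example, with source parameters `[1.0, 1.0]` in every layer and generated parameters `[1.5, 0.5]` for `layer1`, `[1.1, 1.0]` for `layer2`, and `[1.0, 1.0]` for `layer4`, `layer1` has a relative change of about 0.5 and is reported as the most changed layer, the pattern Figure 3(a) describes for input-level shifts.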