
Multi-Scale and Multi-Layer Contrastive Learning for Domain Generalization

Aristotelis Ballas, Christos Diou

TL;DR

The paper tackles domain generalization in image classification by exploiting multi-scale and multi-layer CNN representations to disentangle domain-invariant attributes. It introduces the M^2 framework with extraction blocks that harvest multi-scale features from intermediate layers and the M^2-CL contrastive objective that enforces invariance of class-discriminative features across domains. Across four DG benchmarks (PACS, VLCS, Office-Home, NICO), the approach achieves state-of-the-art results and is supported by saliency analyses showing focus on causally relevant object features rather than context. While effective, the method incurs memory overhead and benefits from larger batch sizes for the contrastive term, pointing to future work on efficiency and integration with causal or attention mechanisms.

Abstract

During the past decade, deep neural networks have led to fast-paced progress and significant achievements in computer vision problems, for both academia and industry. Yet despite their success, state-of-the-art image classification approaches fail to generalize well to previously unseen visual contexts, as required by many real-world applications. In this paper, we focus on this domain generalization (DG) problem and argue that the generalization ability of deep convolutional neural networks can be improved by taking advantage of multi-layer and multi-scale representations of the network. We introduce a framework that aims to improve domain generalization of image classifiers by combining both low-level and high-level features at multiple scales, enabling the network to implicitly disentangle representations in its latent space and learn domain-invariant attributes of the depicted objects. Additionally, to further facilitate robust representation learning, we propose a novel objective function, inspired by contrastive learning, which aims to constrain the extracted representations to remain invariant under distribution shifts. We demonstrate the effectiveness of our method by evaluating it on the PACS, VLCS, Office-Home and NICO domain generalization datasets. Through extensive experimentation, we show that our model is able to surpass the performance of previous DG methods and consistently produce competitive and state-of-the-art results on all datasets.
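The objective is only described qualitatively in this summary: per the Figure 2 caption, M^2-CL brings representations of images with the same label closer together and pushes those of different classes apart, irrespective of the source domain. As a hedged illustration of that behaviour, here is a minimal, dependency-free sketch of a generic supervised contrastive loss in the style of SupCon. It is not claimed to be the paper's exact formula; the function name `supcon_loss`, the temperature `tau = 0.1`, and the use of cosine similarity are assumptions.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    # Project onto the unit hypersphere so dot products become cosine similarities.
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def supcon_loss(features: Sequence[Sequence[float]], labels: Sequence[int],
                tau: float = 0.1) -> float:
    """Generic SupCon-style loss; a hypothetical stand-in for M^2-CL.

    For each anchor i with same-class partners P(i):
        L_i = -1/|P(i)| * sum_{p in P(i)} log( exp(z_i.z_p/tau) / sum_{a != i} exp(z_i.z_a/tau) )
    Domain labels are deliberately ignored: only class labels define positives.
    """
    z = [_normalize(f) for f in features]
    n = len(z)
    total, anchors = 0.0, 0
    for i in range(n):
        positives = [p for p in range(n) if p != i and labels[p] == labels[i]]
        if not positives:
            continue  # an anchor without a same-class partner contributes nothing
        # Temperature-scaled similarity of anchor i to every other sample in the batch.
        sims = {a: sum(x * y for x, y in zip(z[i], z[a])) / tau for a in range(n) if a != i}
        # Numerically stable log-sum-exp over all non-anchor samples.
        m = max(sims.values())
        log_denom = m + math.log(sum(math.exp(s - m) for s in sims.values()))
        # Mean negative log-probability of each positive under the softmax over the batch.
        total += -sum(sims[p] - log_denom for p in positives) / len(positives)
        anchors += 1
    return total / anchors if anchors else 0.0


# Two classes (0 = cat, 1 = dog), each imagined as coming from two different domains.
aligned = [[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.1, 0.9]]  # same-class embeddings agree
mixed = [[1.0, 0.1], [0.0, 1.0], [0.9, 0.0], [0.1, 0.9]]    # same-class embeddings disagree
labels = [0, 0, 1, 1]
print(supcon_loss(aligned, labels) < supcon_loss(mixed, labels))  # True
```

Because each anchor's denominator sums over every other sample in the batch, larger batches supply more positives and negatives per anchor, which is consistent with the TL;DR's remark that the contrastive term benefits from larger batch sizes.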

Paper Structure (16 sections, 6 equations, 7 figures, 4 tables)

Figures (7)

  • Figure 1: In this work, the main goal is to improve the ability of a model to uncover a disentangled representation of an input image. This disentangled representation can be thought of as a sequence of class-specific ($a_1, a_2$) or domain-specific ($d_1, d_2, d_3, d_4$, but perhaps class-irrelevant) attributes. For example, a domain-specific feature can be attributed to a "patch of grass" in the input image, whereas the "whiskers" or the "shape of ear" can be thought of as class-specific. We argue that the problems caused by domain shift between data drawn from unknown domains can be mitigated by utilizing multiple levels of information passed throughout a Convolutional Neural Network in order to derive disentangled representations. A classifier trained on such disentangled representations can then learn to base its predictions only on the class-specific, or causal, attributes of the object depicted in the image (blue rectangles).
  • Figure 2: Each object in a certain class consists of distinguishable class-specific attributes which remain invariant between domains, e.g. a cat has whiskers whether there is snow or grass in the background of the image. However, the image also contains domain-specific but not necessarily class-relevant attributes, which are entangled in the representation extracted by the final convolutional layers of popular CNN architectures such as ResNets. The goal in DG is to encode such common class-specific attributes of images in order to enable models to identify target classes across data domains. Classic fully supervised models trained solely via empirical risk minimization (ERM) tend to correlate the latent space representations with features found in distinct domains. In the above visualization, each representation attributed to a certain class is illustrated with a different shape (circle, square, triangle), while each color (blue, green, red) corresponds to a distinct data domain. To mitigate the issues caused by domain shifts in data distributions originating from different data-generating processes, in this work we propose an alternative approach which: a) attempts to derive a set of disentangled representations by extracting multiple levels of information from intermediate CNN layers of the backbone network ($\mathbf{M^2}$), and b) brings the representations extracted from images of the same label closer together, while simultaneously pushing those originating from different classes further apart ($\mathbf{M^2}$-CL).
  • Figure 3: Visualization of the $\mathbf{M^2}$ architecture built on a ResNet-18. We propose extracting feature maps (black arrows) from layers across the ResNet-18 with the use of multiple extraction blocks (green boxes). The lines above the network's conv layers represent ResNet skip connections. The solid lines indicate that the feature maps retain their dimension, while the feature maps passed through the dashed connection lines are downsampled to match the dimension of the previous layer output. Our network's main functionality derives from the multiple parallel concentration pipelines in each extraction block. By utilizing feature maps from intermediate outputs of convolutional layers in the backbone model, our framework combines low-level, multi-scale features of early layers with more complex features extracted at layers further down the network. The parallel pipelines aim at processing the extracted feature maps in a multi-scale manner, each emphasizing a different characteristic of the object depicted in the input image. We argue that by incorporating outputs from different intermediate levels of the network, we enable the model to disentangle the invariant qualities of an image.
  • Figure 4: Visualization of the Extraction Block and Concentration Pipeline implementation. Each block can be connected to any intermediate layer of the backbone model and is followed by multiple concentration pipelines. Each pipeline consists of a $1 \times 1$ Convolutional, a Spatial Dropout, a Max Pooling and a Flatten layer; the result is then passed through a multilayer perceptron (MLP) and connected to the framework's concatenation layer, as depicted in Figure 3.
  • Figure 5: Visualization of the $\mathbf{M^2}$ architecture built on a ResNet-50 model. Similar to the ResNet-18 implementation in Figure 3, we once again extract feature maps by using multiple extraction blocks (green boxes). Our proposed extraction blocks are connected to the last convolutional layer of each of the ResNet's bottleneck blocks.
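The concentration pipeline of Figure 4 (1×1 convolution, spatial dropout, max pooling, flatten, MLP) and an extraction block running several such pipelines in parallel can be sketched in dependency-free Python on a tiny feature map. This is a hedged illustration, not the authors' implementation: every function name, the single ReLU layer standing in for the MLP, the channel counts, and the assumption that parallel pipelines differ only in their max-pooling window (one plausible source of the "multiple scales") are hypothetical.

```python
import random
from typing import List, Sequence

FeatureMap = List[List[List[float]]]  # indexed as [channel][row][col]


def conv1x1(x: FeatureMap, weight: Sequence[Sequence[float]]) -> FeatureMap:
    # 1x1 convolution: a per-pixel linear mix of the input channels (bias omitted for brevity).
    h, w = len(x[0]), len(x[0][0])
    return [[[sum(wc * x[c][i][j] for c, wc in enumerate(w_out)) for j in range(w)]
             for i in range(h)] for w_out in weight]


def spatial_dropout(x: FeatureMap, p: float, training: bool, rng: random.Random) -> FeatureMap:
    # Spatial dropout drops entire channels rather than single activations;
    # kept channels are rescaled by 1/(1-p) so the expected activation is unchanged.
    if not training or p == 0.0:
        return x
    out: FeatureMap = []
    for channel in x:
        scale = 0.0 if rng.random() < p else 1.0 / (1.0 - p)
        out.append([[v * scale for v in row] for row in channel])
    return out


def max_pool(x: FeatureMap, k: int) -> FeatureMap:
    # Non-overlapping k x k max pooling (stride k); incomplete border windows are dropped.
    h, w = len(x[0]) // k, len(x[0][0]) // k
    return [[[max(ch[i * k + a][j * k + b] for a in range(k) for b in range(k))
              for j in range(w)] for i in range(h)] for ch in x]


def flatten(x: FeatureMap) -> List[float]:
    return [v for ch in x for row in ch for v in row]


def mlp(v: Sequence[float], weight: Sequence[Sequence[float]]) -> List[float]:
    # A single ReLU layer stands in for the pipeline's MLP (depth and width are assumptions).
    return [max(0.0, sum(wi * vi for wi, vi in zip(row, v))) for row in weight]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> List[List[float]]:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def extraction_block(x: FeatureMap, pool_sizes: Sequence[int], mid_channels: int,
                     out_dim: int, p_drop: float, rng: random.Random,
                     training: bool = True) -> List[float]:
    """Run parallel concentration pipelines on one intermediate feature map and concatenate them.

    Weights are freshly sampled on each call purely for illustration; a real module learns them.
    """
    outputs: List[float] = []
    for k in pool_sizes:  # one pipeline per (hypothetical) pooling scale
        y = conv1x1(x, rand_matrix(mid_channels, len(x), rng))
        y = spatial_dropout(y, p_drop, training, rng)
        flat = flatten(max_pool(y, k))
        outputs.extend(mlp(flat, rand_matrix(out_dim, len(flat), rng)))
    return outputs


rng = random.Random(0)
fmap = [[[rng.random() for _ in range(4)] for _ in range(4)] for _ in range(2)]  # 2 channels, 4x4
features = extraction_block(fmap, pool_sizes=[1, 2, 4], mid_channels=3, out_dim=5,
                            p_drop=0.2, rng=rng)
print(len(features))  # 15: three pipelines x out_dim 5
```

In the full framework, as described for Figure 3, one such block would be attached to each tapped backbone layer and the resulting vectors concatenated before classification. In PyTorch, the same stages would map naturally onto `nn.Conv2d` with `kernel_size=1`, `nn.Dropout2d` (which zeroes whole channels), `nn.MaxPool2d` and `nn.Flatten`.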