Towards Context-Aware Domain Generalization: Understanding the Benefits and Limits of Marginal Transfer Learning
Jens Müller, Lars Kühmichel, Martin Rohbeck, Stefan T. Radev, Ullrich Köthe
TL;DR
The paper studies when context about an input can improve predictions in unseen domains, formalizing context as a permutation-invariant representation of a set of data points drawn from the same domain as the input. It formulates necessary criteria for context to help, analyzes which distribution shifts the approach is robust to, and uses a set encoder to capture contextual information, enabling environment-conditioned predictions. Experiments on synthetic and ProDAS datasets illustrate scenarios where context improves performance and show that out-of-distribution domains can be detected reliably. This detection makes it possible to select between the most predictive and the most robust model rather than committing to a single trade-off between accuracy and robustness, with implications for safer and more reliable cross-domain learning.
Abstract
In this work, we analyze the conditions under which information about the context of an input $X$ can improve the predictions of deep learning models in new domains. Following work on marginal transfer learning in Domain Generalization (DG), we formalize the notion of context as a permutation-invariant representation of a set of data points that originate from the same domain as the input itself. We offer a theoretical analysis of the conditions under which this approach can, in principle, yield benefits, and formulate two necessary criteria that can be easily verified in practice. Additionally, we contribute insights into the kinds of distribution shifts for which the marginal transfer learning approach promises robustness. Empirical analysis shows that our criteria are effective in discerning both favorable and unfavorable scenarios. Finally, we demonstrate that we can reliably detect scenarios where a model is tasked with unwarranted extrapolation in out-of-distribution (OOD) domains, identifying potential failure cases. Consequently, we showcase a method to select between the most predictive and the most robust model, circumventing the well-known trade-off between predictive performance and robustness.
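To make the notion of a permutation-invariant context representation concrete, the following is a minimal sketch, assuming a mean-pooled set encoder with random, untrained weights. It is not the authors' architecture: the function names (`init_params`, `set_encoder`, `predict`), the layer sizes, and the linear prediction head are all hypothetical and chosen only for illustration. The sketch shows that a prediction conditioned on the pooled context does not depend on the order of the context samples.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_params(d_in, d_hidden, d_context):
    # Hypothetical random weights; a real model would learn these.
    return {
        "W1": rng.normal(size=(d_in, d_hidden)) / np.sqrt(d_in),
        "b1": np.zeros(d_hidden),
        "W2": rng.normal(size=(d_hidden, d_context)) / np.sqrt(d_hidden),
        "b2": np.zeros(d_context),
    }

def set_encoder(context_set, params):
    """Map an (n, d_in) set of same-domain samples to a (d_context,) vector."""
    # Per-element features: each row is transformed independently.
    h = np.tanh(context_set @ params["W1"] + params["b1"])
    # Mean pooling over the set axis removes any dependence on row order.
    pooled = h.mean(axis=0)
    return pooled @ params["W2"] + params["b2"]

def predict(x, context_set, enc_params, w_out):
    """Context-conditioned prediction: a linear head on [x, context]."""
    c = set_encoder(context_set, enc_params)
    return np.concatenate([x, c]) @ w_out

# Assumed toy dimensions: 3 input features, 4 context features.
params = init_params(d_in=3, d_hidden=16, d_context=4)
w_out = rng.normal(size=3 + 4)

# 32 samples standing in for unlabeled data from one domain.
domain_samples = rng.normal(loc=2.0, size=(32, 3))
x = domain_samples[0]

# Shuffling the context set should leave the prediction unchanged.
shuffled = domain_samples[rng.permutation(len(domain_samples))]
y_original = predict(x, domain_samples, params, w_out)
y_shuffled = predict(x, shuffled, params, w_out)
print(np.allclose(y_original, y_shuffled))
```

Mean pooling is only one order-independent aggregation; sum or max pooling would preserve the invariance equally well, and the choice here is an assumption of the sketch rather than something stated in the abstract.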
