Topology-aware Robust Optimization for Out-of-distribution Generalization
Fengchun Qiao, Xi Peng
TL;DR
This work addresses out-of-distribution (OOD) generalization by introducing topology-aware robust optimization (TRO), which explicitly leverages the topology among data distributions. TRO operates in two phases. Topology learning constructs a distributional graph from physical priors or from data-driven, diffusion-based measures. Learning on topology performs constrained robust optimization over mixtures of training groups, guided by a prior derived from that topology. The authors prove fast convergence for convex losses and establish convergence guarantees for non-convex losses, together with generalization bounds that incorporate the topology. They also report state-of-the-art results on classification, regression, and semantic segmentation with both physical and data-driven topologies. The data-driven topology, in particular, aligns with domain knowledge and improves explainability, suggesting that TRO is a practical and theoretically grounded framework for OOD resilience in real-world settings.
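The following minimal Python sketch makes the topology-learning phase more concrete. It shows one way a distributional graph could be turned into a prior over training groups, and it is not the paper's construction. It assumes each training group is summarized by a feature centroid. It uses a Gaussian-kernel affinity as a stand-in for the data-driven, diffusion-based measure. It takes the prior to be the stationary distribution of a random walk on that graph, which for an undirected weighted graph is proportional to each node's weighted degree. The names `group_graph` and `topology_prior` and the `bandwidth` parameter are hypothetical.

```python
import math
from typing import List, Sequence


def group_graph(centroids: Sequence[Sequence[float]], bandwidth: float = 1.0) -> List[List[float]]:
    """Hypothetical distributional graph: Gaussian-kernel affinity between group
    centroids (a stand-in for a diffusion-based measure). No self-loops."""
    n = len(centroids)
    W = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                # Squared Euclidean distance between the two group centroids
                d2 = sum((a - b) ** 2 for a, b in zip(centroids[i], centroids[j]))
                W[i][j] = math.exp(-d2 / (2.0 * bandwidth ** 2))
    return W


def topology_prior(W: List[List[float]]) -> List[float]:
    """Hypothetical topology-derived prior: the stationary distribution of a
    random walk on an undirected weighted graph, i.e. weighted degree / total."""
    degrees = [sum(row) for row in W]
    total = sum(degrees)
    return [d / total for d in degrees]


# Usage: two nearby groups and one isolated group
W = group_graph([[0.0, 0.0], [0.1, 0.0], [3.0, 0.0]], bandwidth=1.0)
prior = topology_prior(W)
```

In this example, the isolated third group receives the smallest prior mass. Mapping graph centrality to a prior this way is one design choice among several. The summary above does not specify how TRO derives its prior from the graph.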
Abstract
Out-of-distribution (OOD) generalization is a challenging machine learning problem, yet it is highly desirable in many high-stakes applications. Existing methods suffer from overly pessimistic modeling with low generalization confidence. As generalizing to arbitrary test distributions is impossible, we hypothesize that further structure on the topology of distributions is crucial in developing strong OOD resilience. To this end, we propose topology-aware robust optimization (TRO), which seamlessly integrates distributional topology into a principled optimization framework. More specifically, TRO solves two optimization objectives: (1) Topology Learning, which explores the data manifold to uncover the distributional topology; and (2) Learning on Topology, which exploits the topology to constrain robust optimization and thereby obtain tightly bounded generalization risk. We theoretically demonstrate the effectiveness of our approach and empirically show that it significantly outperforms the state of the art on a wide range of tasks, including classification, regression, and semantic segmentation. Moreover, we empirically find that the data-driven distributional topology is consistent with domain knowledge, enhancing the explainability of our approach.
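A KL-penalized variant of group robust optimization can illustrate the Learning on Topology objective. This sketch assumes the topology yields a strictly positive prior over training groups. An adversary picks group weights q to maximize sum_g q_g * loss_g - lam * KL(q || prior). This maximization has the closed-form solution q_g ∝ prior_g * exp(loss_g / lam). The penalized form, the names `worst_case_weights` and `robust_loss`, and the parameter `lam` are assumptions for illustration. The paper's actual constraint set may differ.

```python
import math
from typing import List, Sequence


def worst_case_weights(losses: Sequence[float], prior: Sequence[float], lam: float) -> List[float]:
    """Closed-form maximizer over the probability simplex of
        sum_g q_g * loss_g - lam * KL(q || prior),
    namely q_g proportional to prior_g * exp(loss_g / lam).
    Assumes every prior entry is strictly positive and lam > 0."""
    if lam <= 0:
        raise ValueError("lam must be positive")
    logits = [math.log(p) + l / lam for p, l in zip(prior, losses)]
    m = max(logits)  # subtract the max for a numerically stable softmax
    unnormalized = [math.exp(z - m) for z in logits]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


def robust_loss(losses: Sequence[float], prior: Sequence[float], lam: float) -> float:
    """Group losses reweighted by the worst-case mixture (the quantity the model
    would minimize, with q held fixed during the parameter update)."""
    q = worst_case_weights(losses, prior, lam)
    return sum(qi * li for qi, li in zip(q, losses))


# Usage: per-group losses for the current model and a hypothetical topology prior
losses = [0.2, 0.5, 1.0]
prior = [0.45, 0.45, 0.1]
q_soft = worst_case_weights(losses, prior, lam=100.0)  # stays close to the prior
q_hard = worst_case_weights(losses, prior, lam=0.01)   # concentrates on the worst group
```

A large `lam` keeps q close to the prior. A small `lam` concentrates weight on the worst-performing group, which approaches a standard worst-group objective. In training, one would compute per-group losses for the current model, treat q as fixed, and minimize the reweighted loss with respect to the model parameters. To enforce a hard KL radius instead of a penalty, one could search over `lam`, for example by bisection.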
