Towards the Causal Complete Cause of Multi-Modal Representation Learning
Jingyao Wang, Siyu Zhao, Wenwen Qiang, Jiangmeng Li, Changwen Zheng, Fuchun Sun, Hui Xiong
TL;DR
The paper tackles the challenge of learning robust multi-modal representations by reframing multi-modal learning (MML) in terms of causal sufficiency and necessity. It defines the Causal Complete Cause ($C^3$) and develops identifiability results that hold without assuming exogeneity or monotonicity, using an instrumental variable and a twin network to estimate $C^3$ through the $C^3$ risk. The proposed Causal Complete Cause Regularization ($C^3$R) is a plug-and-play framework that enforces causal completeness by jointly optimizing the empirical $C^3$ risk, instrumental-variable guidance, and counterfactual regularization. The authors provide theoretical guarantees and show through extensive experiments on diverse datasets that $C^3$R improves both average and worst-case performance, particularly when modalities are missing or spurious correlations are present. The work advances MML by contributing practical identifiability tools and regularization principles that ground representations in causal content rather than confounded associations.
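Why exogeneity and monotonicity matter can be seen in the classical probability of necessity and sufficiency (PNS) for a binary cause, $\mathrm{PNS} = P(Y_{x=1}=1, Y_{x=0}=0)$, which is roughly the binary-cause analogue of what $C^3$ quantifies for representations. When both assumptions hold, PNS equals the observable difference $P(Y=1 \mid X=1) - P(Y=1 \mid X=0)$; otherwise it generally does not. The toy simulation below is a sketch under assumed structural equations: the confounder `U`, the noise `V`, the probabilities 0.5/0.7/0.8/0.2, and the helpers `sample_units`, `y_mono`, `y_nonmono`, `true_pns`, and `observed_difference` are all hypothetical and do not come from the paper. It is not the authors' $C^3$ estimator.

```python
import random

def sample_units(n, seed=0):
    """Draw n units (u, v) from a hypothetical toy SCM."""
    rng = random.Random(seed)
    units = []
    for _ in range(n):
        u = rng.random() < 0.5  # hidden binary confounder
        v = rng.random() < 0.7  # exogenous noise for the outcome
        units.append((u, v))
    return units

def y_mono(x, u, v):
    # Monotone outcome: Y(1) >= Y(0) for every unit.
    return int((x == 1 and v) or u)

def y_nonmono(x, u, v):
    # Non-monotone outcome: X flips Y depending on U (x XOR u).
    return int(x != u)

def true_pns(units, f):
    # PNS = P(Y_{x=1} = 1, Y_{x=0} = 0), using BOTH potential outcomes of each unit,
    # which is only possible because we simulate the counterfactual world.
    hits = sum(f(1, u, v) == 1 and f(0, u, v) == 0 for u, v in units)
    return hits / len(units)

def observed_difference(units, f, confounded, seed=1):
    # Estimate P(Y=1 | X=1) - P(Y=1 | X=0) from ONE observed X per unit.
    rng = random.Random(seed)
    counts = {0: [0, 0], 1: [0, 0]}  # x -> [number of units, number with Y=1]
    for u, v in units:
        # Confounded: X depends on U (exogeneity fails). Otherwise X is randomized.
        p_treat = (0.8 if u else 0.2) if confounded else 0.5
        x = int(rng.random() < p_treat)
        counts[x][0] += 1
        counts[x][1] += f(x, u, v)
    return counts[1][1] / counts[1][0] - counts[0][1] / counts[0][0]

units = sample_units(200_000)
print(true_pns(units, y_mono))                                   # ≈ 0.35
print(observed_difference(units, y_mono, confounded=False))      # ≈ 0.35
print(observed_difference(units, y_mono, confounded=True))       # ≈ 0.74
print(true_pns(units, y_nonmono))                                # ≈ 0.50
print(observed_difference(units, y_nonmono, confounded=False))   # ≈ 0.00
```

With a randomized cause and a monotone outcome, the observable difference matches PNS. Confounding inflates it, and a non-monotone outcome drives it toward zero even though half of the units respond to X in the "sufficient and necessary" way. These two failure modes are the ones that identification without exogeneity or monotonicity has to handle.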
Abstract
Multi-Modal Learning (MML) aims to learn effective representations across modalities for accurate prediction. Existing methods typically focus on modality consistency and modality specificity to learn such representations. From a causal perspective, however, these objectives can yield representations that contain insufficient and unnecessary information. To address this, we propose that effective MML representations should be causally sufficient and necessary. Given practical issues such as spurious correlations and modality conflicts, we relax the exogeneity and monotonicity assumptions prevalent in prior work and study a concept specific to MML, the Causal Complete Cause ($C^3$). We first define $C^3$, which quantifies the probability that a representation is causally sufficient and necessary. We then discuss the identifiability of $C^3$ and introduce an instrumental variable to support identifying $C^3$ under non-exogeneity and non-monotonicity. Building on this, we derive a measure of $C^3$, the $C^3$ risk. We propose a twin network to estimate it through (i) a real-world branch that uses the instrumental variable to assess sufficiency, and (ii) a hypothetical-world branch that applies gradient-based counterfactual modeling to assess necessity. Theoretical analyses confirm the reliability of this estimate. Based on these results, we propose $C^3$ Regularization, a plug-and-play method that enforces the causal completeness of learned representations by minimizing the $C^3$ risk. Extensive experiments demonstrate its effectiveness.
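The role of an instrumental variable when exogeneity fails can be illustrated with the textbook linear case. In the toy sketch below, a hidden confounder `U` drives both `X` and `Y`, so regressing `Y` on `X` is biased, whereas an instrument `Z` that affects `Y` only through `X` recovers the true effect via the Wald ratio $\mathrm{Cov}(Z, Y) / \mathrm{Cov}(Z, X)$. This is the generic linear IV estimator, not the paper's IV-guided real-world branch; the coefficients (`beta=2.0`, confounding strength 3.0, unit-variance noise) and the helpers `simulate`, `ols_slope`, and `iv_slope` are hypothetical choices for illustration.

```python
import random

def cov(a, b):
    # Unbiased sample covariance of two equal-length lists.
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    return sum((p - ma) * (q - mb) for p, q in zip(a, b)) / (len(a) - 1)

def simulate(n, beta=2.0, seed=0):
    """Hypothetical linear SCM in which U confounds X and Y."""
    rng = random.Random(seed)
    z, x, y = [], [], []
    for _ in range(n):
        u = rng.gauss(0, 1)                 # hidden confounder (never observed)
        zi = rng.gauss(0, 1)                # instrument: independent of U
        xi = zi + u + rng.gauss(0, 1)       # X is endogenous: it depends on U
        yi = beta * xi + 3.0 * u + rng.gauss(0, 1)  # Z reaches Y only through X
        z.append(zi)
        x.append(xi)
        y.append(yi)
    return z, x, y

def ols_slope(x, y):
    # Regress Y on X directly; biased because Cov(X, U) != 0.
    return cov(x, y) / cov(x, x)

def iv_slope(z, x, y):
    # Wald / ratio IV estimator: Cov(Z, Y) / Cov(Z, X).
    return cov(z, y) / cov(z, x)

z, x, y = simulate(100_000)
print(ols_slope(x, y))     # ≈ 3.0 (true effect 2 plus confounding bias 3 * 1/3)
print(iv_slope(z, x, y))   # ≈ 2.0 (recovers beta)
```

In this model $\mathrm{Var}(X) = 3$ and $\mathrm{Cov}(X, U) = 1$, so OLS converges to $2 + 3 \cdot \tfrac{1}{3} = 3$, while the instrument isolates the variation in `X` that is unrelated to `U`.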
