Training Transitive and Commutative Multimodal Transformers with LoReTTa

Manuel Tran, Yashin Dicente Cid, Amal Lahiani, Fabian J. Theis, Tingying Peng, Eldad Klaiman

TL;DR

LoReTTa introduces a self-supervised multimodal pre-training strategy that leverages commutativity and transitivity to connect disjoint modality pairs through a linking modality. By unifying causal (next-token) and masked modeling, and by learning bidirectional transitions across modalities, it can handle unseen modality combinations such as (A, C) and (A, B, C) without requiring all combinations during training. The method is validated on three datasets: the synthetic SVL-MNIST, the medical TCGA-OMICS, and the reinforcement learning MUGEN-GAME. On all three, LoReTTa consistently outperforms strong baselines on missing-modality tasks and cross-modal generation. These results suggest a scalable path toward universal multimodal transformers capable of leveraging partial data in safety-critical domains, with attention to resource usage and potential societal impacts.

Abstract

Training multimodal foundation models is challenging due to the limited availability of multimodal datasets. While many public datasets pair images with text, few combine images with audio or text with audio. Even rarer are datasets that align all three modalities at once. Critical domains such as healthcare, infrastructure, or transportation are particularly affected by missing modalities. This makes it difficult to integrate all modalities into a large pre-trained neural network that can be used out-of-the-box or fine-tuned for different downstream tasks. We introduce LoReTTa (Linking mOdalities with a tRansitive and commutativE pre-Training sTrAtegy) to address this understudied problem. Our self-supervised framework unifies causal modeling and masked modeling with the rules of commutativity and transitivity. This allows us to transition within and between modalities. As a result, our pre-trained models are better at exploring the true underlying joint probability distribution. Given a dataset containing only the disjoint combinations (A, B) and (B, C), LoReTTa can model the relation A <-> C with A <-> B <-> C. In particular, we show that a transformer pre-trained with LoReTTa can handle any mixture of modalities at inference time, including the never-seen pair (A, C) and the triplet (A, B, C). We extensively evaluate our approach on a synthetic, medical, and reinforcement learning dataset. Across different domains, our universal multimodal transformer consistently outperforms strong baselines such as GPT, BERT, and CLIP on tasks involving the missing modality tuple.
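
The two rules named in the abstract can be written out as a short worked example. This is a sketch rather than the paper's own notation: the lowercase symbols $a$, $b$, $c$ for samples of modalities A, B, and C are introduced here for illustration, and the transitivity line additionally assumes that C is conditionally independent of A given the linking modality B, which the abstract does not state.

```latex
% Commutativity: the chain rule factorizes the joint in either order,
% so a model can be trained to generate a from b and b from a.
p(a, b) = p(a \mid b)\, p(b) = p(b \mid a)\, p(a)

% Transitivity (assumption: c is independent of a given b):
% the unseen relation A -> C is reached through the linking modality B.
p(c \mid a) = \sum_{b} p(c \mid b)\, p(b \mid a)
```

Under that assumption, learning A <-> B from the pairs (A, B) and B <-> C from the pairs (B, C) is enough to relate A and C even though no (A, C) sample is ever observed.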

Paper Structure

This paper contains 10 sections, 5 equations, 2 figures, and 4 tables.

Figures (2)

  • Figure 1: Venn diagrams showing the relationship between datasets with different modalities $A$, $B$, and $C$. Overlapping datasets indicate that the dataset contains samples with aligned modalities (i.e., audio, image, and text files belonging to the same concept). While recent work has mostly focused on datasets where at least some samples have all modality combinations available (a, b), we investigate the case where some modality combinations, e.g. $(A, C)$ and $(A, B, C)$, are missing entirely (c).
  • Figure 2: LoReTTa consists of two novel self-supervised strategies: commutative and transitive modeling. (a) In commutative modeling, we apply causal modeling in a commutative manner to generate modality $A$ from $B$ and modality $B$ from $A$ -- given the aligned input sample $(A, B)$. For the aligned but disjoint data point $(B', C)$, we apply the same technique. (b) To ensure that the model learns bidirectional relations between the input tokens, we apply a modified variant of generative pre-training called causal masked modeling. (c, d) Next, we use transitive pre-training to learn any missing conditional joint distributions. The idea is simple. We randomly select a sample and use the linking modality $B$ to predict the missing modality $C$, which is then used to reconstruct the existing modality $A$. The last step is crucial because it ensures that all modalities are properly aligned.
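
The commutative and transitive steps described in the Figure 2 caption can be sketched as plain token-sequence construction. This is a minimal illustration, not the authors' implementation: the `BOS` marker ids, the function names `commutative_sequence` and `transitive_example`, the `generate` callback standing in for the model's sampling routine, and the choice to compute the loss only on the reconstructed modality are all assumptions made here. Causal masked modeling (panel b) is not covered.

```python
import random
from typing import Callable, Dict, List, Tuple

Tokens = List[int]

# Hypothetical start-of-modality marker ids (assumption, not from the paper).
# Content tokens in the examples use ids >= 10 so they never collide.
BOS: Dict[str, int] = {"A": 0, "B": 1, "C": 2}


def commutative_sequence(sample: Dict[str, Tokens], rng: random.Random) -> Tokens:
    """Concatenate the modalities of an aligned pair in a random order.

    Training a next-token model on both orders teaches it to generate
    A from B and B from A, which is the commutative idea in panel (a).
    """
    names = list(sample)
    rng.shuffle(names)
    seq: Tokens = []
    for name in names:
        seq.append(BOS[name])
        seq.extend(sample[name])
    return seq


def transitive_example(
    sample: Dict[str, Tokens],
    generate: Callable[[Tokens], Tokens],
    missing: str,
    link: str = "B",
) -> Tuple[Tokens, List[int]]:
    """Build one transitive training example from an aligned pair.

    1. Prompt the model with the linking modality to sample the missing one.
    2. Condition on that pseudo sample and reconstruct the existing modality.
    Returns the token sequence and a 0/1 mask marking the target positions
    (here: only the tokens of the existing modality, an assumption).
    """
    existing = next(name for name in sample if name != link)
    prompt = [BOS[link], *sample[link], BOS[missing]]
    pseudo = generate(prompt)  # stand-in for sampling from the model
    seq = [BOS[missing], *pseudo, BOS[existing], *sample[existing]]
    mask = [0] * (1 + len(pseudo)) + [0] + [1] * len(sample[existing])
    return seq, mask
```

For a pair such as `{"A": [10, 11], "B": [20, 21, 22]}`, `transitive_example` with `missing="C"` first prompts the model with the B tokens, then returns a sequence in which the sampled C tokens precede the real A tokens. The real A tokens serve as ground truth for the loss, matching the reconstruction step the caption calls crucial for alignment.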