MC2SleepNet: Multi-modal Cross-masking with Contrastive Learning for Sleep Stage Classification

Younghoon Na, Hyun Keun Ahn, Hyun-Kyung Lee, Yoongeol Lee, Seung Hun Oh, Hongkwon Kim, Jeong-Gun Lee

TL;DR

This paper introduces MC2SleepNet, a multi-modal sleep stage classifier that jointly processes raw EEG and spectrogram inputs using a CNN and a Transformer backbone, respectively. It advances multi-modal learning through epoch-level InfoNCE contrastive alignment and a novel sequence-level Cross-Masking scheme that drives cross-attention between the modalities, followed by a fine-tuning stage with frozen backbones. The model achieves state-of-the-art accuracy on SleepEDF-78 (84.6%) and SHHS (88.6%), generalizing well across dataset sizes and improving performance on challenging sleep stages. The approach offers a scalable framework for robust sleep staging that combines cross-modal supervision with self-supervised learning, and could improve automated PSG analysis in clinical settings.
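The epoch-level InfoNCE alignment described above can be illustrated with a short NumPy sketch. It is not the authors' implementation: it assumes that row i of each batch holds the CNN embedding and the Transformer embedding of the same 30-s epoch (the positive pair), that all other rows act as negatives, and that the loss is averaged over both directions (a common symmetric variant). The function name `info_nce` and the temperature of 0.1 are hypothetical choices.

```python
import numpy as np

def info_nce(z_raw: np.ndarray, z_spec: np.ndarray, temperature: float = 0.1) -> float:
    """Hypothetical symmetric InfoNCE between paired epoch embeddings.

    z_raw:  (N, d) features from the CNN branch (raw EEG).
    z_spec: (N, d) features from the Transformer branch (spectrogram).
    Row i of both arrays is assumed to come from the same epoch.
    """
    # L2-normalise so the dot product becomes cosine similarity
    a = z_raw / np.linalg.norm(z_raw, axis=1, keepdims=True)
    b = z_spec / np.linalg.norm(z_spec, axis=1, keepdims=True)
    logits = a @ b.T / temperature  # (N, N); diagonal entries are the positives

    def diag_cross_entropy(l: np.ndarray) -> float:
        # numerically stable row-wise log-softmax; the target class is the diagonal
        l = l - l.max(axis=1, keepdims=True)
        log_prob = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return float(-np.mean(np.diag(log_prob)))

    # average the raw->spec and spec->raw directions (symmetric variant, an assumption)
    return 0.5 * (diag_cross_entropy(logits) + diag_cross_entropy(logits.T))
```

Perfectly matched pairs give a small loss, while misaligned pairs (for example, the spectrogram batch shifted by one epoch with `np.roll`) give a much larger one. Minimizing this loss is what pulls the two modalities' epoch features toward each other.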

Abstract

Sleep profoundly affects our health, and sleep deficiency or disorders can cause physical and mental problems. Despite significant findings from previous studies, challenges persist in optimizing deep learning models, especially in multi-modal learning for high-accuracy sleep stage classification. Our research introduces MC2SleepNet (Multi-modal Cross-masking with Contrastive learning for Sleep stage classification Network). It aims to enable effective collaboration between Convolutional Neural Networks (CNNs) and Transformer architectures in multi-modal training, with the help of contrastive learning and cross-masking. Raw single-channel EEG signals and their corresponding spectrograms provide two differently characterized modalities for multi-modal learning. MC2SleepNet achieves state-of-the-art accuracy of 84.6% on SleepEDF-78 and 88.6% on the Sleep Heart Health Study (SHHS). These results demonstrate that the proposed network generalizes effectively across both small and large datasets.
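To make the second modality concrete, the sketch below turns one raw single-channel EEG epoch into a log-magnitude spectrogram with a short-time Fourier transform written in plain NumPy. The parameters are assumptions drawn from common sleep-staging practice, not values confirmed by the paper: a 100 Hz sampling rate (so a 30-s epoch has 3000 samples), a 2-s Hamming window with 50% overlap, and a 256-point FFT. The function name `epoch_spectrogram` is hypothetical.

```python
import numpy as np

def epoch_spectrogram(x: np.ndarray, fs: int = 100, win_sec: float = 2.0,
                      overlap: float = 0.5, nfft: int = 256) -> np.ndarray:
    """Hypothetical log-magnitude STFT of one single-channel EEG epoch.

    Returns an array of shape (n_frames, nfft // 2 + 1): time frames by frequency bins.
    """
    win = int(win_sec * fs)              # window length in samples (200 at 100 Hz)
    hop = int(win * (1 - overlap))       # hop size (100 samples -> 50% overlap)
    window = np.hamming(win)
    n_frames = 1 + (len(x) - win) // hop
    # slice the epoch into overlapping, Hamming-weighted frames
    frames = np.stack([x[i * hop:i * hop + win] * window for i in range(n_frames)])
    # one-sided FFT of each frame, zero-padded to nfft points
    mag = np.abs(np.fft.rfft(frames, n=nfft, axis=1))
    return 20 * np.log10(mag + 1e-8)     # decibel scale; epsilon avoids log(0)
```

Under these assumptions, a 3000-sample epoch yields a 29 x 129 time-frequency image, with bin k corresponding to k * fs / nfft Hz. An image of this kind is what a Transformer backbone can consume as a sequence of frames, while the CNN branch sees the raw 3000-sample waveform.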

Paper Structure

This paper contains 14 sections, 22 equations, 3 figures, and 5 tables.

Figures (3)

  • Figure 1: MC2SleepNet takes both raw signals and spectrograms as input. The raw signals are passed through a CNN-based backbone, while the spectrograms are fed into a Transformer-based backbone. Pre-training is carried out concurrently at the granularity of both epochs and sequences. To mitigate potential discrepancies between the features extracted from each modality, MC2SleepNet employs an InfoNCE loss. A random masking strategy with 50% probability then forces the model to draw on features from the other modality through the cross-attention layers.
  • Figure 2: A process for generating a spectrogram from a raw signal.
  • Figure 3: Confusion matrices for the two datasets under different masking ratios (M).
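The sequence-level cross-masking in Figure 1 can be sketched as follows. Only the 50% masking probability and the use of cross-attention come from the caption. Everything else is an assumption made for illustration: each epoch in each modality is hidden independently with probability `p_mask`, a zero vector stands in for a mask token, a single attention head is used, masked streams query the other modality's unmasked features, and a residual connection is added. The names `cross_attention`, `cross_mask_step`, and the `params` keys are hypothetical.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)   # stabilise before exponentiating
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(q_feats, kv_feats, wq, wk, wv):
    """Single-head scaled dot-product attention: queries from one modality,
    keys and values from the other."""
    q, k, v = q_feats @ wq, kv_feats @ wk, kv_feats @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (L, L) epoch-to-epoch scores
    return softmax(scores, axis=-1) @ v

def cross_mask_step(h_raw, h_spec, params, p_mask=0.5, rng=None):
    """Hypothetical cross-masking step over a sequence of L epoch features.

    h_raw, h_spec: (L, d) per-epoch features from the CNN and Transformer branches.
    params: {"raw<-spec": (wq, wk, wv), "spec<-raw": (wq, wk, wv)}, each (d, d).
    """
    rng = np.random.default_rng() if rng is None else rng
    L = h_raw.shape[0]
    # independent Bernoulli(p_mask) mask per epoch and per modality; True = hidden
    m_raw = rng.random(L) < p_mask
    m_spec = rng.random(L) < p_mask
    # zero vectors stand in for a mask token (assumption)
    raw_in = np.where(m_raw[:, None], 0.0, h_raw)
    spec_in = np.where(m_spec[:, None], 0.0, h_spec)
    # each masked stream recovers context by attending to the other modality
    raw_out = raw_in + cross_attention(raw_in, h_spec, *params["raw<-spec"])
    spec_out = spec_in + cross_attention(spec_in, h_raw, *params["spec<-raw"])
    return raw_out, spec_out, m_raw, m_spec
```

Because a hidden epoch carries no information of its own in this sketch, the only way its output can be useful is through the attention weights over the other modality's sequence. That is the pressure the 50% masking is meant to create. Replacing the zero vector with a learned mask token, or using complementary masks across the two branches, are plausible alternatives that the caption does not settle.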