
MSA-CNN: A Lightweight Multi-Scale CNN with Attention for Sleep Stage Classification

Stephan Goerttler, Yucheng Wang, Emadeldeen Eldele, Min Wu, Fei He

TL;DR

This work introduces MSA-CNN, a lightweight multi-scale CNN with attention for sleep stage classification on multivariate polysomnography signals. It couples a novel Multi-Scale Module (MSM) with complementary pooling, a global spatial convolution, and a Temporal Context Module (TCM) based on multi-head self-attention to capture contextual dependencies while keeping the parameter count small. Across three public datasets, the large MSA-CNN outperforms nine state-of-the-art baselines in accuracy and Cohen's kappa, often with an order of magnitude fewer parameters, and ablation studies confirm the contributions of MSM, TCM, and multivariate inputs. A visualization tool provides interpretability by illustrating incoming and outgoing attention. The results suggest practical potential for deployment in clinical and resource-constrained settings; future work will explore unsupervised learning and waveform-level explanations.
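The TCM is built on multi-head self-attention over the sequence of feature tokens. As a rough illustration of that building block, the sketch below is a generic NumPy implementation of multi-head scaled dot-product self-attention. The helper name `multi_head_self_attention`, the random projection matrices, and the sizes (`n_tokens`, `d_model`, `n_heads`) are illustrative assumptions, not the authors' code or configuration, and the sketch omits the feed-forward layers that the paper's TCM interleaves with attention.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row maximum for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(tokens, w_q, w_k, w_v, w_o, n_heads):
    """Hypothetical helper: generic multi-head self-attention on (n_tokens, d_model).

    All weight matrices are (d_model, d_model). Returns the output sequence and
    the per-head attention weights of shape (n_heads, n_tokens, n_tokens).
    """
    n_tokens, d_model = tokens.shape
    if d_model % n_heads:
        raise ValueError("d_model must be divisible by n_heads")
    d_head = d_model // n_heads

    # (n_tokens, d_model) -> (n_heads, n_tokens, d_head)
    def split(x):
        return x.reshape(n_tokens, n_heads, d_head).transpose(1, 0, 2)

    # Query, key and value maps, split into heads
    q, k, v = split(tokens @ w_q), split(tokens @ w_k), split(tokens @ w_v)

    # Scaled dot-product attention per head; each row of weights sums to 1
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    weights = softmax(scores, axis=-1)

    # Weighted sum of values, then merge heads back to (n_tokens, d_model)
    context = (weights @ v).transpose(1, 0, 2).reshape(n_tokens, d_model)
    return context @ w_o, weights


rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 12, 16, 4  # illustrative sizes, not from the paper
tokens = rng.standard_normal((n_tokens, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

out, attn = multi_head_self_attention(tokens, w_q, w_k, w_v, w_o, n_heads)
print(out.shape, attn.shape)  # (12, 16) (4, 12, 12)
```

Because every token attends to every other token, the output for one time step can depend on the whole epoch, which is the kind of contextual dependency the TCM is meant to capture.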

Abstract

Recent advancements in machine learning-based signal analysis, coupled with open data initiatives, have fuelled efforts in automatic sleep stage classification. Despite the proliferation of classification models, few have prioritised reducing model complexity, which is a crucial factor for practical applications. In this work, we introduce Multi-Scale and Attention Convolutional Neural Network (MSA-CNN), a lightweight architecture featuring as few as ~10,000 parameters. MSA-CNN leverages a novel multi-scale module employing complementary pooling to eliminate redundant filter parameters and dense convolutions. Model complexity is further reduced by separating temporal and spatial feature extraction and using cost-effective global spatial convolutions. This separation of tasks not only reduces model complexity but also mirrors the approach used by human experts in sleep stage scoring. We evaluated both small and large configurations of MSA-CNN against nine state-of-the-art baseline models across three public datasets, treating univariate and multivariate models separately. Our evaluation, based on repeated cross-validation and re-evaluation of all baseline models, demonstrated that the large MSA-CNN outperformed all baseline models on all three datasets in terms of accuracy and Cohen's kappa, despite its significantly reduced parameter count. Lastly, we explored various model variants and conducted an in-depth analysis of the key modules and techniques, providing deeper insights into the underlying mechanisms. The code for our models, baselines, and evaluation procedures is available at https://github.com/sgoerttler/MSA-CNN.
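The evaluation reports accuracy alongside Cohen's kappa, which corrects the observed agreement $p_o$ for the agreement $p_e$ expected by chance from the class frequencies: $\kappa = (p_o - p_e)/(1 - p_e)$. This matters for sleep staging, where some stages (typically N2) make up a large share of a night. The function below is a minimal stdlib sketch of that formula; the helper name `cohens_kappa` and the ten labelled epochs are hypothetical and are not data from the paper. In practice, `sklearn.metrics.cohen_kappa_score` computes the same quantity.

```python
from collections import Counter


def cohens_kappa(y_true, y_pred):
    """Hypothetical helper: Cohen's kappa = (p_o - p_e) / (1 - p_e)."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(y_true)
    # Observed agreement p_o: the plain accuracy
    p_o = sum(t == p for t, p in zip(y_true, y_pred)) / n
    # Chance agreement p_e: product of the two marginal class frequencies, summed over classes
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    p_e = sum(true_counts[c] * pred_counts[c] for c in true_counts) / n**2
    if p_e == 1.0:
        raise ValueError("kappa is undefined when chance agreement is 1")
    return (p_o - p_e) / (1 - p_e)


# Hypothetical scored epochs (AASM stage labels), not data from the paper
y_true = ["W", "W", "N1", "N2", "N2", "N2", "N3", "N3", "REM", "REM"]
y_pred = ["W", "N1", "N1", "N2", "N2", "N3", "N3", "N3", "REM", "W"]

acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
kappa = cohens_kappa(y_true, y_pred)
print(f"accuracy={acc:.2f}, kappa={kappa:.3f}")  # accuracy=0.70, kappa=0.625
```

Here 70% raw agreement shrinks to a kappa of 0.625, because 20% agreement would already be expected by chance from the label frequencies alone.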

Paper Structure

This paper contains 19 sections, 5 equations, 6 figures, and 6 tables.

Figures (6)

  • Figure 1: Full architecture of our proposed Multi-Scale and Attention Convolutional Neural Network (MSA-CNN). The Multi-Scale Module (MSM, see Figure 2) extracts high-level spectro-morphological features from the input sleep epoch. Subsequently, a global spatial convolution detects co-activation patterns across all input channels, yielding time-dependent feature tokens. These tokens are then passed to our Temporal Context Module (TCM, see Figure 3), which adjusts the representation of each token according to its surrounding context. The time average of these tokens is then passed to a fully connected layer, which yields the final classification of the input signal.
  • Figure 2: Illustration of our Multi-Scale Module (MSM) with four scales. The MSM uses complementary pooling to extract features across a wide spectral range. The first temporal convolution extracts low-level features at each scale, where the scale is set by the size of the preceding pooling layer. The receptive field of each convolution is shaded in the input signal. Complementary pooling after the convolution brings all scales to the same temporal resolution, allowing their feature maps to be merged. A second temporal convolution integrates all four scales, yielding high-level spectro-morphological features.
  • Figure 3: Temporal context module using multi-head self-attention. Multi-head attention (defined in the paper's multi-head attention equation) requires the computation of query (blue), key (green), and value (red) maps. The sequence of attention and feed-forward layers (grey area) is repeated $N_{lay}$ times.
  • Figure 4: Performance of model variants and ablation study on the datasets ISRUC-S3 (A), Sleep-EDF-20 (B), and Sleep-EDF-78 (C) in terms of test accuracy. The proposed MSA-CNN model is configured as small (bright blue) for the ISRUC-S3 dataset and as large (dark blue) for the larger Sleep-EDF datasets, with the other model size serving as a model variant. For the ablation study, we changed the model to univariate (brown), replaced the multi-scale convolutions with a single uni-scale convolution (coloured by scale), or removed attention from the proposed model (red). Light (dark) colours indicate a modification of the small (large) MSA-CNN. Error bars depict the standard error of the mean across folds, paired with the proposed model; significant deviations from the proposed model, established using a paired t-test, are indicated above the error bars ($\cdot$: $p<0.1$, $*$: $p<0.05$, $**$: $p<0.01$, $***$: $p<0.001$).
  • Figure 5: Parameter sensitivity of test accuracy for MSA-CNN small on ISRUC-S3 and MSA-CNN large on Sleep-EDF-20. (A) Mean test accuracy as a function of the number of channels. The shaded area shows the standard error of the mean across repetition-averaged folds. Starting from a univariate configuration, channels are added successively in a predetermined order. (B) Test accuracy as a function of the number of contiguous scales in the multi-scale module. Each circle depicts a different configuration of contiguous scales. The solid lines show the mean test accuracy, and the shaded area shows the minimum and maximum for each number of scales.
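Figure 1 describes a global spatial convolution that detects co-activation patterns across all input channels and yields one feature token per time step. The NumPy sketch below shows one way such a layer can work, assuming a standard (non-depthwise) convolution whose kernel spans every input channel at a single time step with no padding. All sizes (`n_maps`, `n_channels`, `n_time`, `d_token`) are illustrative assumptions, not the paper's configuration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not the paper's configuration)
n_maps, n_channels, n_time = 8, 6, 50  # MSM output: feature maps x input channels x time steps
d_token = 16                           # number of spatial filters = token dimension

features = rng.standard_normal((n_maps, n_channels, n_time))

# Kernel of height n_channels and width 1, no padding: at a single time step it
# covers every input channel and every feature map, collapsing the channel axis
kernel = rng.standard_normal((d_token, n_maps, n_channels))
bias = np.zeros(d_token)

# One d_token-dimensional feature token per time step
tokens = np.einsum("fmc,mct->tf", kernel, features) + bias
print(tokens.shape)  # (50, 16)

# The parameter count depends on channels and feature maps, not on sequence length
n_params = kernel.size + bias.size
print(n_params)  # 784
```

At each time step, this layer is simply a linear map from the flattened (feature map × channel) vector to a token. That is why it stays cheap: 784 weights here, regardless of how many time steps the epoch has.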
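Figure 2 describes complementary pooling: each branch is pooled before the first temporal convolution, which sets its scale, and pooled again afterwards so that the branches can be merged. The single-channel, single-filter NumPy sketch below illustrates this idea under stated assumptions. Each branch pre-pools by `p` and post-pools by `total_pool // p`, convolutions use `'same'` padding, and a 30-s epoch is sampled at an assumed 100 Hz. The helper names, kernel, and pooling sizes are hypothetical, not the paper's settings.

```python
import numpy as np


def avg_pool(x: np.ndarray, size: int) -> np.ndarray:
    # Non-overlapping average pooling (stride = size); drops any incomplete tail block
    n = len(x) // size * size
    return x[:n].reshape(-1, size).mean(axis=1)


def multi_scale_branches(signal, pre_pool_sizes, total_pool, kernel):
    """Hypothetical sketch of complementary pooling for one channel and one filter."""
    branches = []
    for p in pre_pool_sizes:
        if total_pool % p:
            raise ValueError(f"pre-pool size {p} must divide total_pool={total_pool}")
        x = avg_pool(signal, p)                  # coarser input -> larger temporal scale
        x = np.convolve(x, kernel, mode="same")  # same kernel size at every scale
        x = avg_pool(x, total_pool // p)         # complementary pooling
        branches.append(x)
    # Each branch is downsampled by total_pool overall, so lengths match and can be stacked
    return np.stack(branches)


fs = 100                                                    # assumed sampling rate (Hz)
epoch = np.random.default_rng(0).standard_normal(30 * fs)  # one 30-s epoch
kernel = np.ones(5) / 5                                     # illustrative kernel of length 5
pre_pool_sizes = [1, 2, 4, 8]

merged = multi_scale_branches(epoch, pre_pool_sizes, total_pool=8, kernel=kernel)
print(merged.shape)  # (4, 375)

for p in pre_pool_sizes:
    # One convolution output covers kernel_size * p input samples
    rf = len(kernel) * p
    print(f"pre-pool {p}: receptive field {rf} samples = {rf / fs:.2f} s")
```

With only 5 weights, the coarsest branch covers 40 input samples. Covering the same span without pre-pooling would need a kernel of length 40, which is the kind of filter-parameter redundancy the MSM is designed to avoid. The stacked `merged` array is what a second temporal convolution would then integrate across scales.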
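Figure 4 marks significance using a paired t-test across folds shared by the proposed model and each variant. A paired test works on the per-fold differences: $t = \bar{d} / (s_d/\sqrt{n})$ with $n-1$ degrees of freedom. The stdlib sketch below computes only the statistic. The helper name `paired_t_statistic` and the per-fold accuracies are hypothetical, not results from the paper. `scipy.stats.ttest_rel(proposed, ablated)` returns the same statistic together with a p-value.

```python
import math
import statistics


def paired_t_statistic(a, b):
    """Hypothetical helper: t statistic and degrees of freedom for paired samples."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("need two equal-length samples with at least 2 pairs")
    diffs = [x - y for x, y in zip(a, b)]
    mean_d = statistics.mean(diffs)
    sd_d = statistics.stdev(diffs)        # sample standard deviation (n - 1 denominator)
    sem_d = sd_d / math.sqrt(len(diffs))  # standard error of the mean difference
    return mean_d / sem_d, len(diffs) - 1


# Hypothetical per-fold test accuracies (not results from the paper)
proposed = [0.82, 0.80, 0.84, 0.81, 0.83]
ablated = [0.80, 0.79, 0.81, 0.80, 0.81]

t, df = paired_t_statistic(proposed, ablated)
print(f"t = {t:.3f}, df = {df}")  # t = 4.811, df = 4
```

Pairing by fold removes the fold-to-fold variation that both models share, so a consistent but small improvement can still reach significance.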