Two-Stage Multi-task Self-Supervised Learning for Medical Image Segmentation

Binyan Hu, A. K. Qin

TL;DR

Data scarcity in medical imaging limits the performance of deep-learning segmentation models. The authors propose a two-stage framework. The first stage trains the target segmentation task together with each auxiliary self-supervised task in both joint-training and pre-training modes and keeps, for each task, whichever model scores better on the validation set. The second stage distills all per-task models into a single student through ensemble knowledge distillation. Pairing the target task with one auxiliary task at a time and choosing the training mode per task is intended to limit negative transfer while still exploiting the complementary knowledge of diverse auxiliary tasks. On the SIIM-ACR Pneumothorax Segmentation dataset with limited training data, the method reports higher Dice scores than conventional training and existing auxiliary-task methods. The approach offers a practical way to combine self-supervised learning (SSL) and auxiliary supervision for medical image segmentation (MIS), with possible extensions in optimization strategies and task diversification.

Abstract

Medical image segmentation has been significantly advanced by deep learning (DL) techniques, but the data scarcity inherent in medical applications poses a great challenge to DL-based segmentation methods. Self-supervised learning offers a solution by creating auxiliary learning tasks from the available dataset and then leveraging the knowledge acquired from solving these auxiliary tasks to help solve the target segmentation task. Different auxiliary tasks may have different properties and thus can help the target task to different extents. It is desirable to leverage their complementary advantages to enhance the overall assistance to the target task. To achieve this, existing methods often adopt a joint training paradigm, which co-solves segmentation and auxiliary tasks by integrating their losses or intermediate gradients. However, direct coupling of losses or intermediate gradients risks undesirable interference because the knowledge acquired from solving each auxiliary task at every training step may not always benefit the target task. To address this issue, we propose a two-stage training approach. In the first stage, the target segmentation task is independently co-solved with each auxiliary task in both joint training and pre-training modes, with the better model selected via validation performance. In the second stage, the models obtained with respect to each auxiliary task are converted into a single model using an ensemble knowledge distillation method. Our approach makes the best use of each auxiliary task to create multiple elite segmentation models and then combines them into an even more powerful model. We employed five auxiliary tasks with different properties in our approach and applied it to train the U-Net model on an X-ray pneumothorax segmentation dataset. Experimental results demonstrate the superiority of our approach over several existing methods.
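
The first-stage selection described in the abstract can be sketched in a few lines of Python. This is only an illustrative outline under stated assumptions, not the authors' implementation: the function name `select_stage_one_models` and the three callables (`train_joint`, `train_pretrain_then_finetune`, `validate`) are hypothetical placeholders, and `validate` is assumed to return a score where higher is better (for example, validation Dice).

```python
from typing import Callable, Dict, Iterable, Tuple, TypeVar

Model = TypeVar("Model")


def select_stage_one_models(
    aux_tasks: Iterable[str],
    train_joint: Callable[[str], Model],
    train_pretrain_then_finetune: Callable[[str], Model],
    validate: Callable[[Model], float],
) -> Dict[str, Tuple[str, Model]]:
    """For each auxiliary task, train in both modes and keep the better model.

    All callables are hypothetical stand-ins supplied by the caller:
    - train_joint(task): segmentation and auxiliary task trained together
    - train_pretrain_then_finetune(task): auxiliary pre-training, then
      fine-tuning on segmentation
    - validate(model): validation score, higher is better (assumed)
    """
    selected: Dict[str, Tuple[str, Model]] = {}
    for task in aux_tasks:
        # Train one candidate per mode for this auxiliary task only.
        candidates = {
            "joint": train_joint(task),
            "pretrain": train_pretrain_then_finetune(task),
        }
        scores = {mode: validate(model) for mode, model in candidates.items()}
        # On a tie, max() keeps the first key, i.e. the joint-training model.
        best_mode = max(scores, key=scores.get)
        selected[task] = (best_mode, candidates[best_mode])
    return selected
```

Because each auxiliary task is paired with the segmentation task on its own, the loop never mixes losses from two auxiliary tasks; the resulting dictionary holds one selected model per task, which is what the second stage would consume as teachers.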

Paper Structure

This paper contains 20 sections, 5 equations, 1 figure, 5 tables.

Figures (1)

  • Figure 1: An illustration of the proposed method. Best viewed in colour. It leverages $N$ auxiliary tasks $\{\mathcal{T}_i\}_{i=1}^{N}$, created from the training set $\mathbb{D}_{seg}^{tr}$, to boost segmentation performance. The method is composed of two training stages. In the first stage, the target segmentation task is independently co-solved with each auxiliary task $\mathcal{T}_i$ in both joint training and pre-training modes. In the joint training mode, the model is trained to concurrently solve the segmentation and auxiliary tasks through a weighted combination of their respective losses $l_{seg}$ and $l_i$. In the pre-training mode, the model is first trained to solve the auxiliary task with $l_i$ and then fine-tuned on the segmentation task with $l_{seg}$. The better of the two resulting models, denoted $f_i$, is selected according to its performance on the validation set $\mathbb{D}_{seg}^{val}$. In the second stage, the models obtained with respect to each auxiliary task $\{f_i\}_{i=1}^N$ (the teachers) are converted into a single model $f_s$ (the student) using an ensemble knowledge distillation method, which matches the student's segmentation output $\hat{\mathbf{y}}_s$ with the ensemble of the teachers' segmentation outputs $\{\hat{\mathbf{y}}_i\}_{i=1}^N$. Finally, the student model $f_s$ is returned for inference.
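
The two loss computations named in the caption can be illustrated with a small, dependency-free Python sketch. Several details here are assumptions rather than facts from the paper: the joint loss is written as $l_{seg} + w \cdot l_i$ with a hypothetical weight `aux_weight`, the teacher ensemble is taken to be a per-pixel mean of foreground probabilities, and a soft Dice loss is used to match the student to the ensemble. The source only states that a weighted combination is used and that the student output is matched to the teachers' ensemble. Images are represented as flat lists of per-pixel probabilities purely for brevity.

```python
from typing import List, Sequence


def joint_loss(l_seg: float, l_aux: float, aux_weight: float) -> float:
    """Weighted combination of segmentation and auxiliary losses.

    The exact weighting scheme is an assumption for illustration.
    """
    return l_seg + aux_weight * l_aux


def ensemble_average(teacher_probs: Sequence[Sequence[float]]) -> List[float]:
    """Per-pixel mean of the teachers' foreground probabilities.

    Simple averaging is an assumed ensemble rule, not confirmed by the source.
    """
    n_teachers = len(teacher_probs)
    return [sum(pixel) / n_teachers for pixel in zip(*teacher_probs)]


def soft_dice_loss(
    pred: Sequence[float], target: Sequence[float], eps: float = 1e-6
) -> float:
    """Soft Dice loss between two probability maps (0 means perfect overlap)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


def distillation_loss(
    student_probs: Sequence[float], teacher_probs: Sequence[Sequence[float]]
) -> float:
    """Match the student's output to the ensemble of teacher outputs."""
    return soft_dice_loss(student_probs, ensemble_average(teacher_probs))
```

In a real training loop these functions would operate on tensors from a deep-learning framework and be minimized by gradient descent; the list-based version only makes the arithmetic explicit. The teachers' outputs are fixed targets, so only the student's parameters would be updated in the second stage.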