Low-Rank Continual Pyramid Vision Transformer: Incrementally Segment Whole-Body Organs in CT with Light-Weighted Adaptation

Vince Zhu, Zhanghexuan Ji, Dazhou Guo, Puyang Wang, Yingda Xia, Le Lu, Xianghua Ye, Wei Zhu, Dakai Jin

TL;DR

This work proposes a new continual whole-body organ segmentation model with lightweight low-rank adaptation (LoRA). The model achieves high segmentation accuracy, closely approaching the PVT and nnUNet upper bounds, and significantly outperforms regularization-based continual semantic segmentation (CSS) methods.

Abstract

Deep segmentation networks achieve high performance when trained on specific datasets. However, in clinical practice, it is often desirable to dynamically extend pretrained segmentation models so that they can segment new organs without access to the previous training datasets and without retraining from scratch. Such a paradigm makes model development and deployment much more efficient while accounting for patient privacy and data storage constraints. This clinically preferred process can be viewed as a continual semantic segmentation (CSS) problem. Previous CSS methods either suffer from catastrophic forgetting or incur unaffordable memory costs as the models expand. In this work, we propose a new continual whole-body organ segmentation model with lightweight low-rank adaptation (LoRA). We first train and freeze a pyramid vision transformer (PVT) base segmentation model on the initial task, and then, for each new learning task, continually add lightweight trainable LoRA parameters to the frozen model. Through a holistic exploration of architecture modifications, we identify the three most important layers (i.e., the patch-embedding, multi-head attention, and feed-forward layers) that are critical for adapting to new segmentation tasks, while keeping the majority of the pretrained parameters fixed. Our proposed model continually segments new organs without catastrophic forgetting while maintaining a low parameter increasing rate. When continually trained and tested on four datasets covering different body parts with a total of 121 organs, our model achieves high segmentation accuracy, closely approaching the PVT and nnUNet upper bounds, and significantly outperforms regularization-based CSS methods. Compared with the leading architecture-based CSS method, our model has a substantially lower parameter increasing rate while achieving comparable performance.
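The core mechanism described above (a frozen base layer plus a small trainable low-rank update added for each new task) can be sketched for a single linear layer as follows. This is a minimal NumPy illustration under stated assumptions, not the authors' implementation: the class name `ContinualLoRALinear`, the values `rank=4` and `alpha=8.0`, and the per-task routing (task `t` uses only its own adapter, task 0 uses the frozen base) are hypothetical choices. The update rule `W0 + (alpha / r) * B @ A` with `B` initialized to zeros follows the original LoRA formulation. In the paper's setting, a wrapper like this would sit on the Q/V projections and feed-forward linear layers.

```python
import numpy as np


class ContinualLoRALinear:
    """Hypothetical sketch: frozen linear layer plus one low-rank adapter per task."""

    def __init__(self, in_features, out_features, rank=4, alpha=8.0, seed=0):
        self.rng = np.random.default_rng(seed)
        # Base weights: assumed trained on task 0, then frozen (never updated).
        self.weight = self.rng.standard_normal((out_features, in_features)) / np.sqrt(in_features)
        self.bias = np.zeros(out_features)
        self.rank = rank
        self.scale = alpha / rank
        self.adapters = {}  # task_id -> (A, B), the only trainable tensors per task

    def add_task(self, task_id):
        """Attach a new trainable low-rank pair (A, B) for a new continual task."""
        out_features, in_features = self.weight.shape
        A = self.rng.standard_normal((self.rank, in_features)) * 0.01
        B = np.zeros((out_features, self.rank))  # zero init: adapter starts as a no-op
        self.adapters[task_id] = (A, B)

    def forward(self, x, task_id=0):
        """x: (n, in_features). Tasks without an adapter (e.g. task 0) use the base only."""
        y = x @ self.weight.T + self.bias
        if task_id in self.adapters:
            A, B = self.adapters[task_id]
            # Low-rank update scale * B @ A, applied without materializing the full matrix.
            y = y + self.scale * (x @ A.T) @ B.T
        return y

    def added_params_per_task(self):
        out_features, in_features = self.weight.shape
        return self.rank * (in_features + out_features)

    def base_params(self):
        return self.weight.size + self.bias.size


if __name__ == "__main__":
    layer = ContinualLoRALinear(in_features=256, out_features=256, rank=4)
    x = np.random.default_rng(1).standard_normal((2, 256))
    layer.add_task(1)
    # With B initialized to zeros, the new task initially reproduces the base output.
    print(np.allclose(layer.forward(x, task_id=1), layer.forward(x, task_id=0)))  # True
    print(layer.added_params_per_task(), layer.base_params())  # 2048 65792
```

In this toy configuration each new task adds 2,048 parameters to a 65,792-parameter layer, about 3% overhead. That ratio is only an illustrative proxy for the paper's "parameter increasing rate", whose exact definition is not given in this summary.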

Paper Structure (9 sections, 3 equations, 2 figures, 2 tables)

Figures (2)

  • Figure 1: (a) Illustration of continual multi-organ segmentation. At each continual learning step, only the previously trained model is available (green arrow); previous datasets are not accessible. (b) Segmentation performance versus parameter increasing rate of continual multi-organ segmentation methods.
  • Figure 2: Overall framework of the proposed low-rank continual pyramid vision transformer (LoCo-PVT) network for continual whole-body organ segmentation. The network is composed of a stack of encoder and decoder blocks, where each block contains a patch embedding (PE) layer and multiple LoCo-PVT layers. The encoder PE layer (LoCo-PE) uses a convolution layer with stride 2 for downsampling, while the decoder PE layer (Deconv-PE) uses a deconvolution layer for upsampling. Continual LoRA is added to the linear layers for the Q/V projections in multi-head attention and to the feed-forward network in each LoCo-PVT layer, as well as to the convolution layers (LoCo-Conv) in LoCo-PE. The base network (blue) is frozen after training on the initial task 0. At each subsequent continual learning step, a set of trainable LoRA parameters and a new segmentation output layer (red) are added to adapt to the new task.
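Figure 2 also places LoRA on the strided convolution of the encoder patch embedding (LoCo-Conv). One simple way to do this, sketched below under assumptions, is to factorize the kernel update as `B @ A` over the flattened `(in_channels * k^3)` fan-in and reshape it back to the kernel shape. The caption states only "stride 2". The kernel size of 2, the absence of padding and bias, the class name `LoRAConv3dStride2`, and this particular factorization are all hypothetical and may differ from the paper's LoCo-Conv. Because kernel size equals stride here, the 3D convolution reduces to non-overlapping patch extraction followed by a matrix product.

```python
import numpy as np


def patchify_3d(x, k):
    """Split a (C, D, H, W) volume into non-overlapping k*k*k patches -> (N, C*k^3)."""
    C, D, H, W = x.shape
    x = x.reshape(C, D // k, k, H // k, k, W // k, k)
    x = x.transpose(1, 3, 5, 0, 2, 4, 6)  # (D/k, H/k, W/k, C, k, k, k)
    return x.reshape(-1, C * k ** 3), (D // k, H // k, W // k)


class LoRAConv3dStride2:
    """Hypothetical sketch: frozen 3D conv (kernel = stride = k, no padding, no bias)
    plus one low-rank kernel update per continual task."""

    def __init__(self, in_ch, out_ch, rank=4, alpha=8.0, k=2, seed=0):
        self.rng = np.random.default_rng(seed)
        self.k = k
        fan_in = in_ch * k ** 3
        # Frozen base kernel, assumed trained on task 0.
        self.weight = self.rng.standard_normal((out_ch, in_ch, k, k, k)) / np.sqrt(fan_in)
        self.rank = rank
        self.scale = alpha / rank
        self.adapters = {}  # task_id -> (A, B)

    def add_task(self, task_id):
        out_ch = self.weight.shape[0]
        fan_in = self.weight[0].size  # in_ch * k^3
        A = self.rng.standard_normal((self.rank, fan_in)) * 0.01
        B = np.zeros((out_ch, self.rank))  # zero init: no change to the base at first
        self.adapters[task_id] = (A, B)

    def kernel(self, task_id=0):
        """Effective kernel W0 + scale * reshape(B @ A) for tasks that have an adapter."""
        w = self.weight
        if task_id in self.adapters:
            A, B = self.adapters[task_id]
            w = w + self.scale * (B @ A).reshape(w.shape)
        return w

    def forward(self, x, task_id=0):
        """x: (C, D, H, W) with D, H, W divisible by k; returns (out_ch, D/k, H/k, W/k)."""
        patches, grid = patchify_3d(x, self.k)
        w = self.kernel(task_id).reshape(self.weight.shape[0], -1)
        y = patches @ w.T  # (N, out_ch), one row per output voxel
        return y.T.reshape(-1, *grid)


if __name__ == "__main__":
    conv = LoRAConv3dStride2(in_ch=32, out_ch=64, rank=4)
    vol = np.random.default_rng(1).standard_normal((32, 8, 8, 8))
    conv.add_task(1)
    print(conv.forward(vol, task_id=1).shape)  # (64, 4, 4, 4)
```

With 32 input and 64 output channels, the frozen kernel holds 16,384 weights, while each task's adapter adds 4 * (256 + 64) = 1,280. The per-task segmentation output layer shown in red in Figure 2 is not modeled in this sketch.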