Medical Referring Image Segmentation via Next-Token Mask Prediction

Xinyu Chen, Yiran Wang, Gaoyang Pang, Jiafu Hao, Chentao Yue, Luping Zhou, Yonghui Li

TL;DR

This paper tackles medical referring image segmentation (MRIS) by reframing the task as autoregressive next-token prediction over a unified multimodal token stream that includes image, text, and mask representations. The authors introduce NTP-MRISeg, a pure Transformer model augmented with three MRIS-specific training strategies: Next-k Token Prediction (NkTP) to mitigate exposure bias, Token-level Contrastive Learning (TCL) to sharpen boundary distinctions and address long-tail token distributions, and memory-based Hard Error Token (HET) optimization to emphasize persistently difficult tokens. The method tokenizes inputs with an Emu3 SBER-MoVQGAN vision tokenizer and a Qwen text tokenizer, enabling end-to-end training without modality-specific fusion modules, and achieves state-of-the-art results on QaTa-COV19 and MosMedData+, with ablations showing the value of each component. The work demonstrates that a streamlined, autoregressive, multimodal approach can surpass traditional MRIS pipelines while leveraging pretrained tokenizers, potentially simplifying deployment and improving robustness in clinical contexts. The key equations define token prediction losses and contrastive objectives (the base next-token loss plus the auxiliary NkTP, TCL, and HET losses), all operating on discrete token sequences $\{i_n\}$ derived from images, text, and masks.

Abstract

Medical Referring Image Segmentation (MRIS) involves segmenting target regions in medical images based on natural language descriptions. While achieving promising results, recent approaches usually involve complex designs for multimodal fusion or multi-stage decoders. In this work, we propose NTP-MRISeg, a novel framework that reformulates MRIS as an autoregressive next-token prediction task over a unified multimodal sequence of tokenized image, text, and mask representations. This formulation streamlines model design by eliminating the need for modality-specific fusion and external segmentation models, and supports a unified architecture for end-to-end training. It also enables the use of pretrained tokenizers from emerging large-scale multimodal models, enhancing generalization and adaptability. More importantly, to address challenges under this formulation, such as exposure bias, long-tail token distributions, and fine-grained lesion edges, we propose three novel strategies: (1) a Next-k Token Prediction (NkTP) scheme to reduce cumulative prediction errors, (2) Token-level Contrastive Learning (TCL) to enhance boundary sensitivity and mitigate long-tail distribution effects, and (3) a memory-based Hard Error Token (HET) optimization strategy that emphasizes difficult tokens during training. Extensive experiments on the QaTa-COV19 and MosMedData+ datasets demonstrate that NTP-MRISeg achieves new state-of-the-art performance, offering a streamlined and effective alternative to traditional MRIS pipelines.
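
As a rough illustration of the formulation described above, the following Python sketch concatenates image, text, and mask token IDs into one sequence and computes a next-token cross-entropy loss. The function names (`build_sequence`, `next_token_loss`), the image-text-mask ordering, and the choice to score only mask-token targets are assumptions made for illustration; they are not taken from the paper.

```python
import math


def build_sequence(image_tokens, text_tokens, mask_tokens):
    """Concatenate token IDs into one stream (ordering is an assumption)."""
    seq = list(image_tokens) + list(text_tokens) + list(mask_tokens)
    n_context = len(image_tokens) + len(text_tokens)
    # Flag the positions that hold mask tokens, i.e. the segmentation output.
    is_mask = [False] * n_context + [True] * len(mask_tokens)
    return seq, is_mask


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def next_token_loss(logits_per_pos, seq, is_mask):
    """Mean cross-entropy for predicting seq[t + 1] from position t.

    Only targets that are mask tokens contribute (an assumption of this
    sketch). logits_per_pos[t] is the model output at position t.
    """
    total, count = 0.0, 0
    for t in range(len(seq) - 1):
        if is_mask[t + 1]:
            total -= log_softmax(logits_per_pos[t])[seq[t + 1]]
            count += 1
    return total / max(count, 1)


# Toy usage: a vocabulary of 4 IDs and uniform (all-zero) logits.
seq, is_mask = build_sequence([0, 1], [2], [3, 1])
uniform_logits = [[0.0] * 4 for _ in seq]
loss = next_token_loss(uniform_logits, seq, is_mask)
```

With uniform logits the loss equals the log of the vocabulary size, which is a convenient sanity check for any implementation of this kind.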

Paper Structure

This paper contains 25 sections, 15 equations, 7 figures, 4 tables.

Figures (7)

  • Figure 1: Comparison of different models for MRIS. (a) Models that integrate an additional parallel U-shaped architecture to align and fuse text features and vision features [li2023lvit, huang2024cross]. (b) Dual-branch fusion architectures that apply cross-attention to align and fuse text features and vision features [hu2023beyond, ouyang2024lsms]. (c) MLLM-based models that align multimodal features and use embedded representations as masks for decoding [hu2024lga, koleilat2024medclip]. (d) Ours: a unified MLLM-based framework that aligns features and directly uses visual tokens as mask inputs to a detokenizer.
  • Figure 2: Overall framework of NTP-MRISeg. (a) Mechanism of NTP: the model predicts each token in the sequence based on preceding tokens, with loss calculated by comparing predicted tokens against Ground Truth (GT) labels. (b) Mechanism of NkTP: the model simultaneously predicts $k$ consecutive tokens based on preceding tokens, with loss calculated across all $k$ predicted tokens against their corresponding GT. (c) Mechanism of TCL: each token uses its corresponding GT as the positive sample and the preceding $m$ predicted tokens ($m=5$ in this example) as negative samples for contrastive learning. (d) Mechanism of HET optimization: error tokens from the previous epoch are ranked by deviation from GT, with the most challenging errors selected to push predictions away from historical error tokens while pulling them closer to GT.
  • Figure 3: Visualization of original and reconstructed medical images and masks using the Emu3 SBER-MoVQGAN tokenizer. (a) Original lung X-ray image, (b) Corresponding segmentation mask, (c) Original lung CT image, (d) Corresponding segmentation mask. Each image and mask is tokenized and then reconstructed from discrete tokens. The preservation of structural and boundary details demonstrates the tokenizer’s suitability for MRIS.
  • Figure 4: Visualization of the main comparison with SOTA methods on QaTa-COV19. The column titled "Medical Descriptions" shows the input textual referring prompt, and the column titled "Image" shows the input image. The column titled "GroundTruth" shows the ground truth segmentation target. The column titled "Ours" shows the result of our NTP-MRISeg. The blue area is the infected region segmented by NTP-MRISeg.
  • Figure 5: Visualization of the main comparison with SOTA methods on MosMedData+. The column titled "Medical Descriptions" shows the input textual referring prompt, and the column titled "Image" shows the input image. The column titled "GroundTruth" shows the ground truth segmentation target. The column titled "Ours" shows the result of our NTP-MRISeg. The red area is the infected region segmented by NTP-MRISeg.
  • ...and 2 more figures
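
The NkTP and TCL mechanisms described in the Figure 2 caption can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the per-offset prediction heads in `nktp_loss`, the InfoNCE-style form of `tcl_loss`, the use of cosine similarity, and the `temperature` parameter are all hypothetical choices, since the caption does not specify them.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def nktp_loss(head_logits, targets, k):
    """Next-k token prediction loss (sketch).

    head_logits[t][j] holds the logits predicted at position t for the
    token at position t + 1 + j, with j = 0..k-1 (one assumed head per
    offset). The loss averages cross-entropy over all valid (t, j) pairs.
    """
    total, count = 0.0, 0
    for t in range(len(targets)):
        for j in range(k):
            pos = t + 1 + j
            if pos < len(targets):
                total -= log_softmax(head_logits[t][j])[targets[pos]]
                count += 1
    return total / max(count, 1)


def cosine(u, v):
    """Cosine similarity between two vectors given as lists."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + 1e-12)


def tcl_loss(anchor, positive, negatives, temperature=0.1):
    """Token-level contrastive loss in an assumed InfoNCE form.

    anchor: embedding of the current prediction.
    positive: embedding of the GT token.
    negatives: embeddings of the preceding m predicted tokens.
    """
    pos = math.exp(cosine(anchor, positive) / temperature)
    neg = sum(math.exp(cosine(anchor, n) / temperature) for n in negatives)
    return -math.log(pos / (pos + neg))


# Toy usage: uniform logits over 3 token IDs with k = 2 heads.
targets = [0, 1, 2, 1]
uniform = [[[0.0] * 3 for _ in range(2)] for _ in targets]
nk = nktp_loss(uniform, targets, k=2)

# A negative orthogonal to the anchor is easy; one identical to the
# positive is maximally confusing.
easy = tcl_loss([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0]])
hard = tcl_loss([1.0, 0.0], [1.0, 0.0], [[1.0, 0.0]])
```

In this toy setup the contrastive loss is near zero when the negative is dissimilar to the anchor and rises when a negative coincides with the positive, which matches the intent of pushing predictions away from nearby incorrect tokens. HET optimization is not sketched here because the caption does not give enough detail on how the error memory is ranked and applied.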