PSA-MIL: A Probabilistic Spatial Attention-Based Multiple Instance Learning for Whole Slide Image Classification

Sharon Peled, Yosef E. Maruvka, Moti Freiman

TL;DR

PSA-MIL tackles the challenge of integrating spatial context into multiple instance learning (MIL) for whole-slide images (WSIs) by reframing self-attention as a probabilistic posterior with learnable spatial priors. It introduces dynamic distance-decay priors, a diversity loss that pushes attention heads toward distinct spatial patterns, and a spatial pruning mechanism that substantially reduces computation while preserving spatial dependencies. Across cancer subtyping, metastasis detection, and survival prediction, it achieves state-of-the-art performance at lower computational cost, underscoring the practical value of data-driven spatial modeling in WSI analysis. The approach offers a principled, adaptable framework for capturing complex tissue structures, with potential for broader clinical deployment.
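The TL;DR's core idea, reading self-attention as a posterior over tiles with a learnable spatial prior, can be made concrete with a short derivation. This is a hedged reconstruction rather than the paper's own equations. It assumes the likelihood term is the standard scaled dot-product score between query $q_i$ and key $k_j$ (key dimension $d_k$), and that the prior over which tile $j$ tile $i$ attends to is proportional to $f(d_{ij}|\theta)$, where $d_{ij}$ is the distance between tile coordinates. The second line additionally assumes an exponential decay family, chosen purely for illustration.

```latex
\begin{aligned}
a_{ij} &= \frac{f(d_{ij}|\theta)\,\exp\!\left(q_i^{\top}k_j/\sqrt{d_k}\right)}
               {\sum_{l} f(d_{il}|\theta)\,\exp\!\left(q_i^{\top}k_l/\sqrt{d_k}\right)}
        = \operatorname{softmax}_j\!\left(\frac{q_i^{\top}k_j}{\sqrt{d_k}} + \log f(d_{ij}|\theta)\right),\\[4pt]
f(d|\theta) &= e^{-\theta d}\ (\theta>0)
\;\Longrightarrow\;
f(d_{ij}|\theta) \ge \tau \iff d_{ij} \le f^{-1}(\tau|\theta) = \frac{\ln(1/\tau)}{\theta}.
\end{aligned}
```

Because the prior's normalizing constant cancels, the prior enters as an additive log-bias on the attention logits. A threshold $\tau$ on the prior then becomes a hard distance cutoff: under this assumed family, pairs farther apart than $\ln(1/\tau)/\theta$ can be skipped. This is the intuition behind pruning the distance matrix with $f^{-1}(\tau|\theta)$.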

Abstract

Whole Slide Images (WSIs) are high-resolution digital scans widely used in medical diagnostics. WSI classification is typically approached using Multiple Instance Learning (MIL), where the slide is partitioned into tiles treated as interconnected instances. While attention-based MIL methods aim to identify the most informative tiles, they often fail to fully exploit the spatial relationships among them, potentially overlooking intricate tissue structures crucial for accurate diagnosis. To address this limitation, we propose Probabilistic Spatial Attention MIL (PSA-MIL), a novel attention-based MIL framework that integrates spatial context into the attention mechanism through learnable distance-decayed priors, formulated within a probabilistic interpretation of self-attention as a posterior distribution. This formulation allows spatial relationships to be inferred dynamically during training, eliminating the predefined assumptions often imposed by previous approaches. Additionally, we propose a spatial pruning strategy for the posterior that reduces the quadratic complexity of self-attention. To further enhance spatial modeling, we introduce a diversity loss that encourages variation among attention heads, so that each head captures distinct spatial representations. Together, these components enable a more data-driven and adaptive integration of spatial context that moves beyond predefined constraints. PSA-MIL achieves state-of-the-art performance against both contextual and non-contextual baselines while significantly reducing computational cost.
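A minimal single-head NumPy sketch of attention with a distance-decay prior and spatial pruning. It assumes the prior is $f(d|\theta) = e^{-\theta d}$, so the cutoff radius is $\ln(1/\tau)/\theta$. The helper names `pruning_radius` and `spatial_posterior_attention` are hypothetical, not taken from the paper. In training, $\theta$ would be a learnable parameter (plausibly one per head); here it is a fixed number.

```python
import numpy as np


def pruning_radius(tau, theta):
    """Distance at which the assumed prior exp(-theta * d) drops to tau: ln(1/tau) / theta."""
    return np.log(1.0 / tau) / theta


def spatial_posterior_attention(x, coords, w_q, w_k, w_v, theta, tau=1e-3):
    """Single-head sketch: attention ∝ exp(q_i·k_j / sqrt(d_k)) * exp(-theta * d_ij), pruned by distance.

    x: (n, dim) tile features; coords: (n, 2) tile grid coordinates.
    Returns the refined tile embeddings (n, dim_v) and the attention matrix (n, n).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    d_k = q.shape[-1]

    # Pairwise Euclidean distances between tile coordinates, shape (n, n)
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))

    # Log-posterior up to a constant: dot-product likelihood plus log-prior log f(d) = -theta * d
    logits = q @ k.T / np.sqrt(d_k) - theta * dist

    # Spatial pruning: pairs beyond the radius get zero weight (a tile always keeps itself, d = 0)
    keep = dist <= pruning_radius(tau, theta)
    logits = np.where(keep, logits, -np.inf)

    # Numerically stable row-wise softmax over the kept pairs
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ v, weights


# Usage: two spatially separated clusters of tiles
rng = np.random.default_rng(0)
n, dim = 6, 8
x = rng.normal(size=(n, dim))
coords = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [5, 5], [5, 6]], dtype=float)
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) / np.sqrt(dim) for _ in range(3))
out, attn = spatial_posterior_attention(x, coords, w_q, w_k, w_v, theta=2.0, tau=0.01)
# With radius ln(100) / 2 ≈ 2.30, tiles in different clusters never attend to each other
```

The dense mask only illustrates the semantics. Actual savings would come from scoring only the neighbor pairs inside the radius (for example, found with a spatial index), so the cost grows with the number of neighbors per tile rather than with n².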

Paper Structure

This paper contains 25 sections, 10 equations, 6 figures, 4 tables.

Figures (6)

  • Figure 1: Comparison of spatial context modeling approaches in WSI analysis. (a) Permutation-invariant MIL treats all patches equally, disregarding spatial relationships. (b) Contextual MIL restricts interactions to a predefined neighborhood around the anchor patch, enforcing a fixed spatial structure. (c) Positional encoding injects a spatial bias into tile representations but provides only limited spatial awareness; because the encoding is not aligned with physical distances, it also hurts interpretability. (d) Our proposed PSA-MIL dynamically learns spatial relationships, starting from an initial distribution over neighboring patches that evolves throughout training, in contrast to previous methods that maintain a fixed spatial structure.
  • Figure 2: PSA-MIL Overview: 1–2. Tissue regions are cropped and encoded using a pretrained feature extractor, producing tile representations. 3–3.1. A multi-head spatial self-attention mechanism is employed to generate informative, spatially correlated tile representations. Within this block, the pairwise distance matrix (bottom left) undergoes dynamic spatial pruning (described in the paper's spatial pruning section) via the learned function $f^{-1}(\tau|\theta)$, producing a spatially pruned distance matrix. The distance-decayed matrix is then computed using $f(d|\theta)$. Both $f^{-1}(\tau|\theta)$ and $f(d|\theta)$ are learnable components parameterized by $\theta$, enabling adaptive pruning and decay. Simultaneously, query, key, and value matrices are generated. The posterior attention distribution is then computed using the paper's final posterior formulation and applied to the value matrix, producing refined tile embeddings that encode spatial interactions. 4. These embeddings are subsequently pooled via an attention-based aggregation mechanism and passed through a classification head to generate slide-level predictions. The final loss function consists of both the classification loss and a diversity loss (defined in the paper) applied to the multi-head attention mechanism to encourage the heads to capture diverse spatial patterns.
  • Figure 3: The Effect of Diversity Loss on Multi-Head Attention. (a–b) Locality evolution: as described in the paper's spatial pruning section, our spatial attention can be interpreted as dynamic local attention. Without diversity loss, all heads converge to similar locality values, whereas with diversity loss, the locality values diverge. (c–d) Inter-head similarity: the similarity metric measures the average (smoothed) $\ell_2$ token similarity across heads. Without diversity loss, similarity remains relatively high; with diversity loss, it decreases. See the paper's training dynamics section for additional details.
  • Figure 4: Visual comparison of the regions of interest (ROIs) highlighted by various MIL models for metastasis detection. PSA-MIL accurately highlights the ROI, whereas the other methods show imprecise localization.
  • Figure 5: Subtyping performance (AUC) vs. mean FLOPs per batch during training (log scale) for contextual MIL models on TCGA-CRC, with bubble size representing parameter count. PSA-MIL delivered the best performance with a notably smaller computational footprint.
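The captions for Figures 2 and 3 describe the training objective only at a high level: refined tile embeddings are pooled by attention and classified, and the classification loss is combined with a diversity loss across attention heads. The NumPy sketch below fills in one plausible version, and every specific in it is an assumption rather than the paper's definition. Those assumptions are the diversity loss as the mean pairwise cosine similarity between heads' $\ell_2$-normalized tile embeddings, ABMIL-style attention pooling, averaging the heads before pooling, and the weight `lam`. All function names are hypothetical.

```python
import numpy as np


def diversity_loss(head_outputs):
    """Hypothetical diversity loss: mean pairwise cosine similarity between attention heads.

    head_outputs: (n_heads, n_tiles, dim), the refined embedding of each tile from each head.
    Minimizing this value pushes heads toward distinct representations.
    """
    n_heads, n_tiles, _ = head_outputs.shape
    # L2-normalize every token so dot products become cosine similarities
    z = head_outputs / np.linalg.norm(head_outputs, axis=-1, keepdims=True)
    # sim[a, b] = mean over tiles of cos(z[a, t], z[b, t])
    sim = np.einsum("atd,btd->ab", z, z) / n_tiles
    # Average over distinct head pairs only (exclude each head's similarity with itself)
    return sim[~np.eye(n_heads, dtype=bool)].mean()


def attention_pool(tokens, v, w):
    """ABMIL-style attention pooling (an assumed form): a_t = softmax_t(w^T tanh(V^T h_t))."""
    scores = np.tanh(tokens @ v) @ w
    scores = scores - scores.max()
    a = np.exp(scores) / np.exp(scores).sum()
    return a @ tokens, a


def cross_entropy(logits, label):
    """Cross-entropy of one slide-level prediction, computed from raw logits."""
    shifted = logits - logits.max()
    return np.log(np.exp(shifted).sum()) - shifted[label]


def total_loss(logits, label, head_outputs, lam=0.1):
    """Classification loss plus a weighted diversity term; lam is a hypothetical weight."""
    return cross_entropy(logits, label) + lam * diversity_loss(head_outputs)


# Usage: identical heads are maximally similar, random heads much less so
rng = np.random.default_rng(0)
heads = rng.normal(size=(4, 10, 16))
identical = np.repeat(heads[:1], 4, axis=0)

tokens = heads.mean(axis=0)                      # (10, 16); simplification: heads averaged, not concatenated
slide_emb, a = attention_pool(tokens, rng.normal(size=(16, 8)), rng.normal(size=8))
logits = slide_emb @ rng.normal(size=(16, 2))    # 2-class slide-level logits
loss = total_loss(logits, 1, heads)
```

Under this definition, identical heads score 1.0 and heads with orthogonal embeddings score 0.0, so minimizing the term pushes heads apart. This is qualitatively the direction shown in Figure 3 (c–d), without any claim to reproduce it.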