Learnable Community-Aware Transformer for Brain Connectome Analysis with Token Clustering
Yanting Yang, Beidi Zhao, Zhuohao Ni, Yize Zhao, Xiaoxiao Li
TL;DR
The paper tackles the rigidity of predefined brain community structures in functional connectome analysis by introducing TC-BrainTF, a token-clustering transformer that learns dynamic, community-aware ROI embeddings. It employs learnable prompt tokens $P \in \mathbb{R}^{K \times E}$ and a Deep Embedded Clustering (DEC) mechanism with an orthogonal loss $\mathcal{L}_{\text{ortho}}$ to produce a soft assignment $A \in \mathbb{R}^{N \times K}$ and merged token representations, followed by a graph readout for task classification. The key contribution is a Token Clustering (TC) module integrated into a transformer encoder. It lets the number of communities $K$ be set flexibly rather than fixed by an atlas, and it improves ASD and gender classification on the ABIDE and HCP datasets. Qualitative analyses further link the learned clusters to neuroscientific interpretations. By removing the reliance on atlas-defined communities, the approach makes brain connectome analysis more adaptable and interpretable and offers a data-driven perspective on functional organization.
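The soft assignment and token merging described above can be illustrated with a small NumPy sketch. This is a hedged illustration, not the authors' implementation: it assumes the $K$ prompt tokens act directly as DEC cluster centroids, uses the standard DEC Student's t kernel (with degrees of freedom `alpha = 1`), and forms the merged community tokens as $A^\top X$. The function names `dec_soft_assignment` and `merge_tokens` and the sizes `N`, `K`, `E` are hypothetical.

```python
import numpy as np

def dec_soft_assignment(X, P, alpha=1.0):
    """Student's t soft assignment of N ROI embeddings X (N x E) to K prompt tokens P (K x E).

    Assumption: prompt tokens are used directly as DEC cluster centroids.
    """
    # Squared Euclidean distance between every ROI embedding and every prompt token: (N, K)
    sq_dist = ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=-1)
    # Standard DEC kernel: q_ij proportional to (1 + d_ij^2 / alpha)^(-(alpha + 1) / 2)
    q = (1.0 + sq_dist / alpha) ** (-(alpha + 1.0) / 2.0)
    # Normalize each row so an ROI's community memberships sum to 1
    return q / q.sum(axis=1, keepdims=True)

def merge_tokens(X, A):
    """Merge ROI embeddings into K community tokens, each an assignment-weighted sum of ROIs."""
    return A.T @ X  # (K, N) @ (N, E) -> (K, E)

# Hypothetical sizes: N ROIs, K communities, embedding width E
rng = np.random.default_rng(0)
N, K, E = 200, 4, 16
X = rng.normal(size=(N, E))  # stand-in for transformer ROI embeddings
P = rng.normal(size=(K, E))  # stand-in for learnable prompt tokens
A = dec_soft_assignment(X, P)
Z = merge_tokens(X, A)
print(A.shape, Z.shape)  # (200, 4) (4, 16)
```

The merge step is where the node representation shrinks from $N$ rows to $K$ rows, so a downstream readout operates on $K \times E$ features instead of $N \times E$.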
Abstract
Neuroscientific research has revealed that the complex brain network can be organized into distinct functional communities, each a cohesive group of regions of interest (ROIs) with strong interconnections. These communities are crucial for understanding the brain's functional organization and its implications for neurological conditions such as Autism Spectrum Disorder (ASD), as well as for biological differences such as those related to gender. Traditional models have been constrained by the need for predefined community clusters, which limits their flexibility and adaptability in deciphering the brain's functional organization. Furthermore, these models were restricted to a fixed number of communities, hindering their ability to represent the brain's dynamic nature accurately. In this study, we present a token clustering brain transformer-based model ($\texttt{TC-BrainTF}$) for joint community clustering and classification. At its core is a novel token clustering (TC) module built on the transformer architecture. The module uses learnable prompt tokens regularized by an orthogonal loss and projects each ROI embedding onto the prompt embedding space; this clusters ROIs into communities and, by merging the ROIs within each community, reduces the dimension of the node representation. Rigorous testing on the ABIDE and HCP datasets shows that our learnable community-aware model $\texttt{TC-BrainTF}$ improves accuracy in identifying ASD and in classifying gender. Additionally, qualitative analysis of $\texttt{TC-BrainTF}$ demonstrates the effectiveness of the designed TC module and its relevance to neuroscientific interpretation.
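The abstract does not give the exact form of the orthogonal loss on the prompt tokens. One common choice, shown here purely as an assumption, penalizes the squared Frobenius distance between the Gram matrix of the L2-normalized prompts and the identity, so the loss is zero exactly when the $K$ prompts are mutually orthogonal. The function name `orthogonal_loss` is hypothetical.

```python
import numpy as np

def orthogonal_loss(P):
    """Assumed orthogonality penalty on K prompt tokens (rows of P, shape K x E).

    Computes || P_hat P_hat^T - I ||_F^2, where P_hat holds the L2-normalized prompts.
    """
    # Normalize each prompt so the penalty depends on direction, not scale
    P_hat = P / np.linalg.norm(P, axis=1, keepdims=True)
    # (K, K) matrix of pairwise cosine similarities between prompts
    gram = P_hat @ P_hat.T
    # Zero iff the normalized prompts are orthonormal
    return float(((gram - np.eye(P.shape[0])) ** 2).sum())

print(orthogonal_loss(np.eye(4, 16)))                      # 0.0: rows already orthonormal
print(orthogonal_loss(np.array([[1.0, 0.0], [1.0, 0.0]])))  # 2.0: two identical prompts
```

Pushing the prompts apart in this way discourages two prompts from collapsing onto the same direction, which would otherwise let two "communities" absorb the same ROIs.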
