BrainRotViT: Transformer-ResNet Hybrid for Explainable Modeling of Brain Aging from 3D sMRI

Wasif Jalal, Md Nafiu Rahman, Atif Hasan Rahman, M. Sohel Rahman

TL;DR

BrainRotViT tackles the challenge of accurate, generalizable brain age estimation from heterogeneous multi-site sMRI by coupling a Vision Transformer encoder, pretrained on age–sex composite classes, with a lightweight residual CNN regression head that operates on a 2D pseudo-image formed from ViT embeddings. The model achieves strong validation performance ($\text{MAE}=3.34$ years, $r=0.98$, $\rho=0.97$, $R^2=0.95$) across 11 datasets and demonstrates robust cross-cohort generalization ($\text{MAE}$ between $3.77$ and $5.04$ years) on four independent cohorts. Interpretability is integrated via guided backpropagation and ViT patch mapping to produce slice-level and 3D attention volumes, highlighting aging-relevant regions such as the cerebellar vermis, precentral/postcentral gyri, temporal lobes, and medial superior frontal gyrus. The findings link brain age gaps to neurological conditions (e.g., AD, MCI, ASD), offering a scalable, efficient, and explainable framework that bridges CNN- and transformer-based approaches for aging and neurodegeneration research.
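To make the idea of an age–sex composite class concrete, the following minimal Python sketch maps a subject's age and sex to a single class index. The function name `age_sex_class`, the 10-year bin width, the 100-year age cap, and the 0/1 sex encoding are all illustrative assumptions; the summary above does not state the grouping the authors actually used.

```python
def age_sex_class(age_years: float, sex: int,
                  bin_width: float = 10.0, max_age: float = 100.0) -> int:
    """Map (age, sex) to one composite class index.

    Assumptions (not from the paper): sex is encoded as 0 or 1, ages are
    grouped into fixed-width bins, and ages at or above max_age fall into
    the last bin.
    """
    if sex not in (0, 1):
        raise ValueError("sex must be 0 or 1 in this sketch")
    n_bins = int(max_age // bin_width)
    # Clamp the bin index to the valid range [0, n_bins - 1].
    age_bin = min(max(int(age_years // bin_width), 0), n_bins - 1)
    # Classes 0..n_bins-1 belong to sex 0, the next n_bins to sex 1.
    return sex * n_bins + age_bin


print(age_sex_class(25.0, 0))  # bin 2 of the sex-0 block
print(age_sex_class(25.0, 1))  # bin 2 of the sex-1 block
```

A coarser or finer `bin_width` changes the number of classes the encoder must separate, which is the kind of granularity trade-off an auxiliary classification task has to balance.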

Abstract

Accurate brain age estimation from structural MRI is a valuable biomarker for studying aging and neurodegeneration. Traditional regression and CNN-based methods face limitations such as manual feature engineering, limited receptive fields, and overfitting on heterogeneous data. Pure transformer models, while effective, require large datasets and incur high computational cost. We propose Brain ResNet over trained Vision Transformer (BrainRotViT), a hybrid architecture that combines the global context modeling of vision transformers (ViT) with the local refinement of residual CNNs. A ViT encoder is first trained on an auxiliary age and sex classification task to learn slice-level features. The frozen encoder is then applied to all sagittal slices to generate a 2D matrix of embedding vectors, which is fed into a residual CNN regressor that incorporates subject sex at the final fully-connected layer to estimate continuous brain age. Our method achieves an MAE of 3.34 years (Pearson $r=0.98$, Spearman $\rho=0.97$, $R^2=0.95$) on validation across 11 MRI datasets encompassing more than 130 acquisition sites, outperforming baseline and state-of-the-art models. It also generalizes well across 4 independent cohorts with MAEs between 3.77 and 5.04 years. Analyses of the brain age gap (the difference between the predicted age and actual age) show that aging patterns are associated with Alzheimer's disease, cognitive impairment, and autism spectrum disorder. Model attention maps highlight aging-associated regions of the brain, notably the cerebellar vermis, precentral and postcentral gyri, temporal lobes, and medial superior frontal gyrus. Our results demonstrate that this method provides an efficient, interpretable, and generalizable framework for brain-age prediction, bridging the gap between CNN- and transformer-based approaches while opening new avenues for aging and neurodegeneration research.
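The evaluation metrics and the brain age gap named in the abstract can be illustrated with a small, dependency-free Python sketch. The helper names and the toy ages are hypothetical, and thresholding the gap at one population standard deviation is an assumption based on the Figure 3 caption, not the authors' code. Spearman's $\rho$ is omitted for brevity.

```python
import math
import statistics


def mae(y_true, y_pred):
    """Mean absolute error between chronological and predicted ages."""
    return sum(abs(p - t) for t, p in zip(y_true, y_pred)) / len(y_true)


def pearson_r(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def brain_age_gap(y_true, y_pred):
    """Brain age gap (BAG) per subject: predicted age minus actual age."""
    return [p - t for t, p in zip(y_true, y_pred)]


def extreme_agers(gaps):
    """Indices of subjects with BAG > +1 SD and BAG < -1 SD.

    Assumption: SD is the population standard deviation of the gaps.
    """
    sd = statistics.pstdev(gaps)
    positive = [i for i, g in enumerate(gaps) if g > sd]
    negative = [i for i, g in enumerate(gaps) if g < -sd]
    return positive, negative


# Toy data, invented for illustration only.
chronological = [20.0, 40.0, 60.0, 80.0]
predicted = [22.0, 39.0, 65.0, 78.0]
gaps = brain_age_gap(chronological, predicted)

print("MAE:", mae(chronological, predicted))
print("Pearson r:", pearson_r(chronological, predicted))
print("R^2:", r_squared(chronological, predicted))
print("BAG:", gaps)
print("Extreme (positive, negative):", extreme_agers(gaps))
```

With this sign convention, a positive gap means the brain appears older than the subject's chronological age.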

Paper Structure

This paper contains 50 sections, 11 equations, 5 figures, 17 tables.

Figures (5)

  • Figure 1: Overview of the proposed framework for 3D brain MRI analysis. The model integrates a Vision Transformer (ViT) encoder to extract slice-level embeddings with a lightweight 2D convolutional regression head to aggregate volumetric information. By applying convolutions over transformer-derived embeddings, the framework jointly models global contextual and local spatial patterns, enabling subsequent analyses on brain age gaps and thus linking aging with neurological conditions such as Alzheimer’s disease, mild cognitive impairment, and autism spectrum disorder.
  • Figure 2: Implementation details of the methodology. (a): Chronological age distribution of the 4086 training samples used in the study. (b): Chronological age distribution of the 1022 validation-set samples used in the study. (c): Pre-processing routine to convert NIfTI-format 3D T1w MRI images to 2D sagittal slices that are used as inputs to our modeling framework. (d): Heatmap of the 160 $\times$ 768 pixel 2D feature map (ViT embedding vectors of 160 slices concatenated in order; log-scaled, min-max scaled, and rotated for better visualization). This is essentially a representation of the input that the residual CNN portion of the network receives for one of the samples. Slices change along the vertical axis, while the horizontal axis is the dimension of the embedding vectors. There is an appearance of vertical symmetry across slices, and adjacent slices appear to have similar embeddings, which creates the appearance of vertical lines of similarly intense positions in the embedding vectors. (e): Heatmap of the 160 $\times$ 160 cosine similarity matrix between the slices of each 3D sample, averaged across all samples in the study. Both axes of the pixel grid (left to right and top to bottom) represent the indices of slices 1 to 160. The principal diagonal (top left to bottom right) represents the cosine similarity of each slice with itself, which is 1, shown as the most intense red. The intensity spreading from the principal diagonal indicates that the vision transformer produces similar embeddings for neighboring slices, while the intensity along the other diagonal suggests that the embeddings of symmetrically antipodal slices are very similar, i.e., the transformer captures the sagittal symmetry of the brain effectively.
  • Figure 3: (a): Scatter plot of predicted age of validation samples from 11 datasets (MAE = 3.34 years, Pearson $r = 0.98$, Spearman $\rho = 0.97$, $R^2 = 0.95$). (b): Scatter plot of predicted age of independent test samples from 4 datasets (MAE = 4.72 years, Pearson $r = 0.88$, Spearman $\rho = 0.90$, $R^2 = 0.83$). (c): Scatter plot of validation on 5,581 ADNI subjects only. (d): Scatter plot of validation on 1,141 ABIDE-II subjects only. (e): Proportion of AD and AD/MCI subjects with extreme positive aging ($\text{BAG} > 1$ std. dev.) and extreme negative aging ($\text{BAG} < -1$ std. dev.) visualized. (f): Proportion of AD/MCI subjects with extreme positive aging ($\text{BAG} > 1$ std. dev.) visualized against other AD/MCI subjects in ADNI. (g): Proportion of ASD subjects with extreme positive aging ($\text{BAG} > 1$ std. dev.) visualized against other ASD subjects in ABIDE-II.
  • Figure 4: Model interpretation. (a): Example of per-slice attention map extracted from the vision transformer. (b): Mid-slice section views of the fused 3D attention map along axial, coronal, and sagittal planes (from left to right). (c): Saliency map of the trained residual CNN extracted through guided backpropagation. (d)–(k): 3D attention map previews by age range.
  • Figure 5: Ablation study results. (a) Comparison of ViT backbone architectures; (b) Impact of age–sex group granularity; (c) Effect of slice count per volume; (d) Front-end architecture comparison; (e) Sex information integration methods; (f) Optimal depth of convolutional blocks; (g) Kernel size configurations; (h) Activation function comparison; (i) Sex fusion effectiveness; (j) Learning rate sensitivity; (k) Impact of residual connections; (l) Loss function comparison on validation and cross-cohort test sets.
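
The slice-to-slice cosine similarity matrix described in the caption of Figure 2(e) can be sketched in plain Python as follows. The function name `cosine_similarity_matrix` and the tiny 4-slice, 2-dimensional embeddings are hypothetical stand-ins for the 160 slices of 768-dimensional ViT embeddings; the toy vectors are mirrored on purpose to imitate sagittal symmetry.

```python
import math


def cosine_similarity_matrix(embeddings):
    """Return the n x n matrix of cosine similarities between row vectors."""
    norms = [math.sqrt(sum(v * v for v in row)) for row in embeddings]
    n = len(embeddings)
    sim = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            dot = sum(a * b for a, b in zip(embeddings[i], embeddings[j]))
            sim[i][j] = dot / (norms[i] * norms[j])
    return sim


# Toy embeddings (illustrative only): slice 0 mirrors slice 3,
# and slice 1 mirrors slice 2.
toy_embeddings = [
    [1.0, 0.0],
    [1.0, 1.0],
    [1.0, 1.0],
    [1.0, 0.0],
]

sim = cosine_similarity_matrix(toy_embeddings)
for row in sim:
    print([round(value, 3) for value in row])
```

In this toy case both the principal diagonal and the anti-diagonal are close to 1, which is the pattern the caption attributes to each slice matching itself and its mirrored counterpart. Averaging such matrices over all subjects would give a population-level heatmap.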