
Ordinal Classification with Distance Regularization for Robust Brain Age Prediction

Jay Shah, Md Mahfuzur Rahman Siddiquee, Yi Su, Teresa Wu, Baoxin Li

TL;DR

This work tackles the regression-to-the-mean (RTM) bias in brain age prediction from MRI by reframing the problem as ordinal classification and introducing the ORDER loss. By encoding the ordinal relationships among age labels into the learned feature space through a Manhattan-distance-based regularization term, the approach preserves age ordering while maintaining high feature entropy. On a large, multi-site healthy lifespan dataset, the method achieves lower mean absolute error (MAE), stronger ordinality, and reduced systematic bias, and it better discriminates Alzheimer's disease stages on an independent ADNI cohort. Validated through ablations, the framework offers a general strategy for building robust aging biomarkers and may extend to other continuous-target tasks affected by RTM bias.
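The TL;DR and Figure 1 describe ORDER only qualitatively. As a rough illustration, the sketch below assumes one hypothetical form of an ordinal distance regularizer: for every pair of samples, the Manhattan (L1) distance between L2-normalized features is pushed toward a value proportional to the absolute age difference, and the result is added to cross-entropy. The names `order_style_regularizer` and `total_loss`, and the hyperparameters `scale` and `lam`, are illustrative assumptions; this is not the paper's equation, and the authors' repository should be consulted for the actual loss.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a feature vector to unit Euclidean norm (zero vectors are left as-is)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def manhattan(a: Sequence[float], b: Sequence[float]) -> float:
    """L1 (Manhattan) distance between two equal-length vectors."""
    return sum(abs(x - y) for x, y in zip(a, b))


def order_style_regularizer(features, ages, scale=0.01):
    """Hypothetical ordinal-distance penalty (an assumption, not the paper's formula).

    For every pair (i, j), the L1 distance between L2-normalized features is
    pushed toward scale * |age_i - age_j|, so subjects with larger age gaps
    are placed farther apart in feature space. Returns the mean pairwise penalty.
    """
    z = [l2_normalize(f) for f in features]
    n = len(z)
    total, pairs = 0.0, 0
    for i in range(n):
        for j in range(i + 1, n):
            target = scale * abs(ages[i] - ages[j])
            total += abs(manhattan(z[i], z[j]) - target)
            pairs += 1
    return total / pairs if pairs else 0.0


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Softmax cross-entropy for one sample, using a stabilized log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def total_loss(batch_logits, batch_labels, features, ages, lam=1.0, scale=0.01):
    """Mean cross-entropy plus lam times the hypothetical ordinal regularizer."""
    ce = sum(cross_entropy(l, y) for l, y in zip(batch_logits, batch_labels))
    ce /= len(batch_labels)
    return ce + lam * order_style_regularizer(features, ages, scale)
```

With unit features `[1, 0]` and `[0, 1]` for ages 20 and 80, the L1 distance is 2 and the target is 0.6, so the penalty is 1.4. Normalizing before measuring distance keeps the feature term bounded, so the regularizer cannot be satisfied simply by inflating feature magnitudes.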

Abstract

Age is one of the major known risk factors for Alzheimer's Disease (AD). Detecting AD early is crucial for effective treatment and for preventing irreversible brain damage. Brain age, a measure derived from brain imaging that reflects structural changes due to aging, may help identify AD onset, assess disease risk, and plan targeted interventions. Deep learning-based regression techniques that predict brain age from magnetic resonance imaging (MRI) scans have recently achieved high accuracy. However, these methods are subject to an inherent regression-to-the-mean effect, which causes a systematic bias: brain age is overestimated in young subjects and underestimated in old subjects. This weakens the reliability of predicted brain age as a valid biomarker for downstream clinical applications. Here, we reformulate brain age prediction from a regression task to a classification task to address this systematic bias. Because preserving the ordinal information in ages is important for understanding aging trajectories and monitoring aging longitudinally, we propose a novel ORdinal Distance Encoded Regularization (ORDER) loss that incorporates the order of age labels, enhancing the model's ability to capture age-related patterns. Extensive experiments and ablation studies demonstrate that this framework reduces systematic bias, outperforms state-of-the-art methods by statistically significant margins, and better captures subtle differences between clinical groups in an independent AD dataset. Our implementation is publicly available at https://github.com/jaygshah/Robust-Brain-Age-Prediction.
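The regression-to-the-mean bias described in the abstract can be quantified in several ways. A common diagnostic, not necessarily the metric reported in the paper, is to regress BrainAGE (predicted minus chronological age) on chronological age: a clearly negative slope means young subjects are overestimated and old subjects underestimated. The sketch below computes that slope alongside MAE; the function names are hypothetical.

```python
from typing import Sequence, Tuple


def mae(pred: Sequence[float], true: Sequence[float]) -> float:
    """Mean absolute error between predicted and chronological ages."""
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(true)


def brainage_bias_slope(pred: Sequence[float], true: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least-squares fit of BrainAGE = pred - true against chronological age.

    Returns (slope, intercept). A slope near 0 suggests little age-dependent bias;
    a negative slope is the typical signature of regression to the mean.
    """
    gaps = [p - t for p, t in zip(pred, true)]
    n = len(true)
    mean_t = sum(true) / n
    mean_g = sum(gaps) / n
    cov = sum((t - mean_t) * (g - mean_g) for t, g in zip(true, gaps))
    var = sum((t - mean_t) ** 2 for t in true)
    if var == 0:
        raise ValueError("chronological ages must not all be identical")
    slope = cov / var
    return slope, mean_g - slope * mean_t
```

For example, predictions of 30, 50, and 70 for subjects aged 20, 50, and 80 give an MAE of about 6.67 years and a slope of -1/3, the pattern the abstract attributes to regression-based models. A constant offset (every prediction 2 years too high) gives a slope of 0 with intercept 2: a nonzero gap, but no age-dependent bias.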

Paper Structure (17 sections, 7 equations, 4 figures, 5 tables)

Figures (4)

  • Figure 1: Standard cross-entropy vs. cross-entropy with ORDER loss. Cross-entropy loss (left) encourages the model to learn high-entropy feature representations in which embeddings are spread out, but it fails to capture ordinal information from the labels. Our proposed ORDER loss combined with cross-entropy (right; the paper's total-loss equation) preserves ordinality by spreading features apart according to the Manhattan distance between normalized features, weighted by the absolute age difference. The illustrated example (right) shows an embedding space in which learned representations of MRI scans with ages 20, 40, and 80 are placed apart from one another, with distances proportional to their absolute age differences.
  • Figure 2: Overview of the proposed brain age prediction framework. (a) A 3D ResNet-18 model is trained on a lifespan cohort with cross-entropy and ORDER losses. Age is predicted as the average of the class ages weighted by the class probabilities from the softmax classifier. (b) At inference, the Brain Age Gap Estimate (BrainAGE) is calculated as the difference between the predicted biological age and the actual chronological age. (c) The trajectory plot offers a visual interpretation of predicted BrainAGE and its associations with aging patterns. In the preclinical AD stage, the patient is cognitively normal, but subtle brain changes due to accelerated aging are already under way and can be captured using BrainAGE.
  • Figure 3: t-SNE visualization of embeddings from the models' penultimate layers. (a) With MSE loss, embeddings maintain ordinal relationships but are tightly packed, resulting in a low-entropy feature space. (b) MSE with a Euclidean distance loss spreads out the embeddings but struggles to preserve ordinal relationships accurately. (c) Cross-entropy (CE) spreads the embeddings further, creating a high-entropy space, but at the cost of losing ordinal information. (d) Mean-variance loss combined with cross-entropy creates a high-entropy feature space and slightly improves ordinality (see the ordinality and bias results table). (e) ORDER loss combined with cross-entropy achieves the best balance: it accurately preserves ordinality, maintains a high-entropy space, and improves overall performance. Embeddings are color-coded by ground-truth age (10 to 95 years).
  • Figure 4: Heatmap of statistical significance between the five clinical groups of AD, given as p-values from t-tests on the predicted BrainAGE of the respective groups, for the MSE model and the cross-entropy with ORDER loss model.
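Figure 2 (a) and (b) describe turning the classifier output into a continuous age and then a BrainAGE value. A minimal sketch follows, assuming one class per integer year over the 10 to 95 range shown in Figure 3; the paper's actual binning may differ, and `AGE_BINS`, `predicted_age`, and `brain_age_gap` are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical binning: one class per integer year, 10 to 95 inclusive (86 classes).
AGE_BINS: Tuple[int, ...] = tuple(range(10, 96))


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def predicted_age(logits: Sequence[float], bins: Sequence[float] = AGE_BINS) -> float:
    """Expected age: each class age weighted by its softmax probability."""
    if len(logits) != len(bins):
        raise ValueError("need exactly one logit per age class")
    return sum(p * a for p, a in zip(softmax(logits), bins))


def brain_age_gap(logits: Sequence[float], chronological_age: float,
                  bins: Sequence[float] = AGE_BINS) -> float:
    """BrainAGE = predicted brain age minus chronological age (positive = older-looking brain)."""
    return predicted_age(logits, bins) - chronological_age
```

Taking the expectation instead of the arg-max class yields a continuous age estimate, so predictions are not restricted to bin centers. With uniform logits over the 86 assumed bins the predicted age is 52.5, so a 50-year-old subject would receive a BrainAGE of 2.5 years.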