
AI-CNet3D: An Anatomically-Informed Cross-Attention Network with Multi-Task Consistency Fine-tuning for 3D Glaucoma Classification

Roshan Kenia, Anfei Li, Rishabh Srivastava, Kaveri A. Thakoor

TL;DR

AI-CNet3D introduces an anatomically informed cross-attention framework for 3D OCT glaucoma classification by linking superior–inferior hemiretinas and macula–ONH regions. The model combines a lightweight 3D CNN backbone with targeted cross-attention blocks and a Channel Attention Representation (CARE) visualization, further refined through multi-task fine-tuning that enforces consistency between CARE and 3D Grad-CAM. Empirical results on two OCT datasets show state-of-the-art performance with substantially fewer parameters and GFLOPS compared to pure transformer models. The approach enhances interpretability and anatomical coherence, offering a data-efficient, clinically relevant solution for 3D OCT analysis and potential biomarker discovery.

Abstract

Glaucoma is a progressive eye disease that leads to optic nerve damage, causing irreversible vision loss if left untreated. Optical coherence tomography (OCT) has become a crucial tool for glaucoma diagnosis, offering high-resolution 3D scans of the retina and optic nerve. However, the conventional practice of condensing information from 3D OCT volumes into 2D reports often results in the loss of key structural details. To address this, we propose a novel hybrid deep learning model that integrates cross-attention mechanisms into a 3D convolutional neural network (CNN), enabling the extraction of critical features from the superior and inferior hemiretinas, as well as from the optic nerve head (ONH) and macula, within OCT volumes. We introduce Channel Attention REpresentations (CAREs) to visualize cross-attention outputs and leverage them for consistency-based multi-task fine-tuning, aligning them with Gradient-Weighted Class Activation Maps (Grad-CAMs) from the CNN's final convolutional layer to enhance performance, interpretability, and anatomical coherence. We have named this model AI-CNet3D (AI-'See'-Net3D) to reflect its design as an Anatomically-Informed Cross-attention Network operating on 3D data. By dividing the volume along two axes and applying cross-attention, our model enhances glaucoma classification by capturing asymmetries between the hemiretinal regions while integrating information from the optic nerve head and macula. We validate our approach on two large datasets, showing that it outperforms state-of-the-art attention and convolutional models across all key metrics. Finally, our model is computationally efficient, reducing the parameter count one-hundredfold compared to other attention mechanisms while maintaining high diagnostic performance and comparable GFLOPS.
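The abstract aligns CAREs with Grad-CAMs taken from the CNN's final convolutional layer. For readers unfamiliar with Grad-CAM, a minimal NumPy sketch of the standard recipe extended to 3D feature volumes follows. It assumes the activations and their gradients have already been extracted from the network (for example with framework hooks); the function name `grad_cam_3d` and the (C, D, H, W) layout are illustrative assumptions, not details from the paper.

```python
import numpy as np

def grad_cam_3d(activations, gradients):
    """Grad-CAM on a 3D feature volume.

    activations, gradients: arrays of shape (C, D, H, W) from the final
    convolutional layer; gradients are d(class score)/d(activations).
    Returns a non-negative (D, H, W) map scaled so its maximum is 1.
    """
    # One weight per channel: average the gradients over the three spatial axes.
    weights = gradients.mean(axis=(1, 2, 3))                            # (C,)
    # Weighted sum of the channel maps, then ReLU keeps only positive evidence.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (D, H, W)
    peak = cam.max()
    return cam / peak if peak > 0 else cam

rng = np.random.default_rng(0)
acts = rng.random((32, 8, 16, 16))             # 32 channels, an illustrative choice
grads = rng.standard_normal((32, 8, 16, 16))
cam = grad_cam_3d(acts, grads)
print(cam.shape)  # (8, 16, 16)
```

For display, the resulting map would typically be upsampled to the input volume's resolution before being overlaid on the OCT slices.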

Paper Structure

This paper contains 34 sections, 16 equations, 6 figures, 9 tables.

Figures (6)

  • Figure 1: An example of how an OCT volume from Dataset 1 can be split along different axes to separate the anatomy in the volume. $CA_H$ is computed using the inferior (I) and superior (S) hemiretinas. $CA_{NA}$ is computed using the ONH (O) and macula (M). $CA_{H-NA}$ is computed using the inferior ONH (IO), inferior macula (IM), superior ONH (SO), and superior macula (SM).
  • Figure 2: Our cross-attention mechanism operates between two pairs of subsections from the feature volume, as highlighted in the green box. In this example, we compute cross-attention between the superior (S) and inferior (I) hemiretina split used for $CA_{H}$. For $CA_{H-NA}$ (not visualized here), we would repeat the calculation in the green box for each pair of quarter-volumes and concatenate the results before adding the skip connection.
  • Figure 3: Visualization of our AI-CNet3D architecture (with the channel dimension omitted from visualization). We apply multiple layers of convolution along with two cross-attention blocks. Filter banks of size 32 are used consistently across the model. When training with multi-task fine-tuning, we utilize the last cross-attention and convolutional layers highlighted in red for alignment.
  • Figure 4: Comparison of CARE and Grad-CAM visualizations from our top-performing AI-CNet3D model before and after multi-task fine-tuning (MTFT) on Dataset 1 (Topcon) for the first four rows and Dataset 2 (Zeiss) for the last four rows. The heatmaps use a scale of intensities to represent importance, with yellow indicating the highest relevance, red indicating moderate importance, and the absence of activation representing the least relevance. The No MTFT model was trained with standard BCE loss for 250 epochs, while the MTFT model was fine-tuned for an additional 250 epochs using a combination of unsupervised MSE loss and supervised BCE loss. Early indicates attention from the second convolutional and first attention layers, whereas Last corresponds to the final convolutional and attention layers. True positive examples are shown above the green line, and true negatives are displayed below. Axial slices appear to the left of the cyan line, while coronal slices are to the right. After MTFT, attention maps from the Last layers exhibit greater consistency, highlighting more stable and interpretable representations.
  • Figure 5: A comparison of what our CARE highlights versus the important regions identified by the clinician. In the final column, we subtract the clinical regions of interest (ROIs) from the CARE heatmap to highlight how our method captures additional features within the volumes beyond the clinical ROIs.
  • ...and 1 more figure
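Figure 1 describes partitioning an OCT volume into superior/inferior hemiretinas, ONH/macula halves, and four quarter-volumes. The NumPy sketch below shows one way to produce those eight sub-volumes. The axis assignments (`SI_AXIS`, `OM_AXIS`), which half is treated as superior or as ONH, and the helper names are hypothetical; the real orientation depends on the scanner and on eye laterality.

```python
import numpy as np

# Hypothetical (Z, Y, X) layout. Which axis separates which anatomy, and which
# half is superior or ONH, are assumptions that depend on scanner and laterality.
SI_AXIS = 1  # superior vs inferior hemiretina
OM_AXIS = 2  # ONH vs macula

def split_two(vol, axis):
    """Split a volume into two halves along `axis` (np.array_split tolerates odd sizes)."""
    first, second = np.array_split(vol, 2, axis=axis)
    return first, second

def split_regions(vol):
    s, i = split_two(vol, SI_AXIS)   # halves used for CA_H
    o, m = split_two(vol, OM_AXIS)   # halves used for CA_NA
    # Quarter-volumes for CA_H-NA: split each hemiretina again along the ONH-macula axis.
    so, sm = split_two(s, OM_AXIS)
    io, im = split_two(i, OM_AXIS)
    return {"S": s, "I": i, "O": o, "M": m, "SO": so, "SM": sm, "IO": io, "IM": im}

vol = np.zeros((64, 128, 128))
parts = split_regions(vol)
print({name: part.shape for name, part in parts.items()})
```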
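The green-box computation in Figure 2 can be approximated with standard scaled dot-product cross-attention. In this minimal sketch, each voxel of a half-volume is a token whose features are its channels, each half queries the other, the outputs are put back in place, and the skip connection adds the input. Sharing the projection matrices `wq`, `wk`, `wv` across both directions, the token layout, and every function name are assumptions rather than the paper's exact formulation.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attend(q_tokens, kv_tokens, wq, wk, wv):
    """Scaled dot-product attention where q_tokens (Nq, C) attend to kv_tokens (Nk, C)."""
    q, k, v = q_tokens @ wq, kv_tokens @ wk, kv_tokens @ wv
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)  # (Nq, Nk), rows sum to 1
    return attn @ v                                          # (Nq, C)

def to_tokens(t):
    # (C, Z, Y, X) -> (Z*Y*X, C): one token per voxel, channels as features.
    return t.reshape(t.shape[0], -1).T

def hemiretina_cross_attention(feat, wq, wk, wv, axis=2):
    """feat: (C, Z, Y, X). Split into two halves along `axis` (assumed S/I axis,
    must be even-sized), let each half attend to the other, put the halves
    back in place, and add the skip connection."""
    s, i = np.split(feat, 2, axis=axis)
    s_out = cross_attend(to_tokens(s), to_tokens(i), wq, wk, wv).T.reshape(s.shape)
    i_out = cross_attend(to_tokens(i), to_tokens(s), wq, wk, wv).T.reshape(i.shape)
    return feat + np.concatenate([s_out, i_out], axis=axis)

rng = np.random.default_rng(0)
C = 8
feat = rng.standard_normal((C, 4, 6, 6))
w = [rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(3)]
out = hemiretina_cross_attention(feat, *w)
print(out.shape)  # (8, 4, 6, 6)
```

Under the same assumptions, the quarter-volume variant would call `cross_attend` once per chosen pair of quarters and concatenate the results before the skip connection, as the Figure 2 caption describes.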
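Figure 3 notes that filter banks of size 32 are used throughout. To give a rough sense of scale for such a lightweight backbone, one 3D convolution has k³·C_in·C_out weights plus C_out biases. The 3×3×3 kernel and the single-channel OCT input below are illustrative assumptions; the caption does not state kernel sizes.

```python
def conv3d_params(c_in: int, c_out: int, k: int = 3, bias: bool = True) -> int:
    """Learnable parameters of one 3D convolution with a k x k x k kernel."""
    weights = k ** 3 * c_in * c_out          # one k^3 kernel per (input, output) channel pair
    return weights + (c_out if bias else 0)  # plus one bias per output channel

print(conv3d_params(1, 32))   # 896: assumed single-channel OCT input to 32 filters
print(conv3d_params(32, 32))  # 27680: 32 filters feeding 32 filters
```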
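Figure 4 describes multi-task fine-tuning that combines a supervised BCE loss with an unsupervised MSE loss between the CARE and Grad-CAM maps. A minimal NumPy sketch of that combined objective follows, assuming both maps have already been resampled to the same shape. The min-max rescaling, the weight `lam`, and all function names are hypothetical; a real training loop would compute this inside an autodiff framework so gradients reach the network.

```python
import numpy as np

def bce(prob, label, eps=1e-7):
    """Binary cross-entropy averaged over the batch (supervised term)."""
    p = np.clip(prob, eps, 1.0 - eps)
    return float(-np.mean(label * np.log(p) + (1.0 - label) * np.log(1.0 - p)))

def rescale(m):
    """Min-max scale a heatmap to [0, 1] so the two maps are comparable."""
    m = m - m.min()
    peak = m.max()
    return m / peak if peak > 0 else m

def mtft_loss(prob, label, care_map, gradcam_map, lam=1.0):
    """BCE on the prediction plus an unsupervised MSE consistency term
    between the CARE map and the Grad-CAM map (same shape assumed)."""
    consistency = float(np.mean((rescale(care_map) - rescale(gradcam_map)) ** 2))
    return bce(prob, label) + lam * consistency

heat = np.random.default_rng(0).random((8, 16, 16))
# Identical maps give zero consistency loss, leaving only BCE(0.5, 1) = ln 2 ≈ 0.6931.
print(mtft_loss(np.array([0.5]), np.array([1.0]), heat, heat))
```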
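Figure 5 subtracts the clinician ROIs from the CARE heatmap to expose features outside them. Assuming the CARE map is scaled to [0, 1] and the ROIs are a binary mask of the same shape (both assumptions), the subtraction reduces to a clipped difference. The helper `fraction_outside_roi` is a hypothetical summary statistic, not one reported in the paper.

```python
import numpy as np

def care_minus_roi(care, roi_mask):
    """Remove clinician-marked regions from a CARE heatmap.
    Assumes `care` is scaled to [0, 1] and `roi_mask` is a 0/1 array of the same shape."""
    return np.clip(care - roi_mask, 0.0, None)

def fraction_outside_roi(care, roi_mask):
    """Share of the total CARE mass that lies outside the clinical ROIs."""
    total = care.sum()
    return float(care_minus_roi(care, roi_mask).sum() / total) if total > 0 else 0.0

care = np.array([[0.9, 0.2], [0.7, 0.0]])
roi = np.array([[1, 0], [0, 0]])
print(care_minus_roi(care, roi))
print(fraction_outside_roi(care, roi))
```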