Enhance Eye Disease Detection using Learnable Probabilistic Discrete Latents in Machine Learning Architectures

Anirudh Prabhakaran, YeKun Xiao, Ching-Yu Cheng, Dianbo Liu

TL;DR

This work tackles reliable ocular disease detection from fundus images by introducing learnable probabilistic discrete latents through GFlowOut, integrated into ResNet18 and Vision Transformer backbones. By modeling the posterior over dropout masks and training with a Trajectory Balance objective, the approach improves accuracy, uncertainty estimation, and calibration under distribution shift and noise. The method yields strong empirical gains on multiple fundus datasets, with improved out-of-distribution (OOD) robustness and interpretable Grad-CAM visualizations that highlight clinically relevant regions. The findings point to more reliable AI-assisted ophthalmic diagnostics and to broader applicability in other medical imaging domains.
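To make the Trajectory Balance objective concrete, here is a minimal sketch in plain Python. It assumes that dropout masks are generated one layer at a time as independent Bernoulli draws from per-unit keep-probabilities, so each layer's mask contributes one forward log-probability term. The helper names (`sample_layer_mask`, `trajectory_balance_loss`), the toy keep-probabilities, and the log-reward value are hypothetical illustrations, not the authors' implementation. In real training, log Z would be a learnable scalar and gradients would flow through the forward log-probabilities; this version only computes the loss value.

```python
import math
import random

EPS = 1e-6  # keeps log() finite when a keep-probability is exactly 0 or 1

def sample_layer_mask(keep_probs, rng):
    """Sample a binary dropout mask for one layer and return (mask, log P_F).

    keep_probs: per-unit probabilities of keeping each unit. In a learned
    scheme these would come from a mask-generator network (hypothetical here);
    ordinary "random" dropout would use one constant value for every unit.
    """
    mask, log_prob = [], 0.0
    for p in keep_probs:
        p = min(max(p, EPS), 1.0 - EPS)
        keep = rng.random() < p
        mask.append(1 if keep else 0)
        log_prob += math.log(p) if keep else math.log(1.0 - p)
    return mask, log_prob

def trajectory_balance_loss(log_z, log_pf_steps, log_reward, log_pb_steps=()):
    """Squared Trajectory Balance residual for one sampled trajectory:

    (log Z + sum_t log P_F - log R - sum_t log P_B)^2

    Assumption: when masks are built one layer at a time, every partial mask
    has a single parent, so P_B = 1 and the backward term defaults to zero.
    """
    residual = log_z + sum(log_pf_steps) - log_reward - sum(log_pb_steps)
    return residual ** 2

if __name__ == "__main__":
    rng = random.Random(0)
    layer_keep_probs = [[0.9, 0.5, 0.1], [0.7, 0.7]]  # toy values
    log_pf = []
    for kp in layer_keep_probs:
        _, lp = sample_layer_mask(kp, rng)
        log_pf.append(lp)
    # Hypothetical log-reward, e.g. log p(y | x, masks) + log p(masks)
    log_reward = -0.4
    print(trajectory_balance_loss(0.0, log_pf, log_reward))
```

The loss is zero exactly when log Z plus the log-probability of sampling a mask trajectory equals that trajectory's log-reward. Driving every trajectory toward this balance is what makes the sampler's mask distribution proportional to the reward.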

Abstract

Ocular diseases, including diabetic retinopathy and glaucoma, present a significant public health challenge due to their high prevalence and potential to cause vision impairment. Early and accurate diagnosis is crucial for effective treatment and management. In recent years, deep learning models have emerged as powerful tools for analysing medical images such as retinal images. However, challenges persist in model reliability and uncertainty estimation, both of which are critical for clinical decision-making. This study leverages the probabilistic framework of Generative Flow Networks (GFlowNets) to learn the posterior distribution over latent discrete dropout masks for the classification and analysis of ocular diseases from fundus images. We develop a robust and generalizable method that integrates GFlowOut with ResNet18 and ViT backbones to identify various ocular conditions. We evaluate four dropout-mask settings (none, random, bottomup, and topdown) to enhance model performance on these fundus images. Our results show that the learnable probabilistic latents significantly improve accuracy, outperforming the traditional dropout approach. We use Grad-CAM, a gradient-based saliency method, to assess model explainability and observe that the model focuses on critical image regions when making predictions. The integration of GFlowOut into neural networks is a promising advance in the automated diagnosis of ocular diseases, with implications for improving clinical workflows and patient outcomes.
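The Grad-CAM step mentioned above can be sketched as follows. This is a minimal, dependency-free illustration of the standard Grad-CAM recipe: weight each feature map by its spatially averaged gradient, sum, apply ReLU, and normalise. It assumes the activations and gradients of the target class score have already been captured from one layer (in practice via framework hooks, not shown). For a ViT, it further assumes the patch tokens (excluding the class token) have been reshaped into a spatial grid. The function name `grad_cam` and the max-normalisation are illustrative choices, not details taken from the paper.

```python
def grad_cam(activations, gradients):
    """Grad-CAM heatmap from one layer's activations and class-score gradients.

    activations, gradients: lists of K feature maps, each an H x W list of
    lists, where gradients[k][i][j] = d(score_c) / d(activations[k][i][j]).
    Returns an H x W map scaled so its maximum is 1 (all zeros if no
    positive evidence).
    """
    h, w = len(activations[0]), len(activations[0][0])
    # alpha_k: global-average-pooled gradient for feature map k
    weights = [sum(sum(row) for row in g) / (h * w) for g in gradients]
    # Weighted sum of feature maps
    cam = [[0.0] * w for _ in range(h)]
    for alpha, fmap in zip(weights, activations):
        for i in range(h):
            for j in range(w):
                cam[i][j] += alpha * fmap[i][j]
    # ReLU keeps only regions that increase the class score
    cam = [[max(v, 0.0) for v in row] for row in cam]
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam

if __name__ == "__main__":
    acts = [[[1, 2], [3, 4]], [[4, 3], [2, 1]]]
    grads = [[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]
    print(grad_cam(acts, grads))
```

The resulting low-resolution map is then upsampled to the input size and overlaid on the fundus image to show which regions drove the prediction.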

Paper Structure

This paper contains 22 sections, 8 equations, 5 figures, 8 tables, and 1 algorithm.

Figures (5)

  • Figure 1: In the Vision Transformer architecture, we apply GFlowOut, a learnable dropout technique, in the transformer encoder. This lets the model learn a posterior distribution over dropout masks tailored to our dataset, improving its performance.
  • Figure 2: Number of data points per class in the RFMiD dataset.
  • Figure 3: Loss and accuracy curves for the models used. The top row shows metrics for the ResNet18 model, and the bottom row shows metrics for the Vision Transformer model. Metrics are plotted for each evaluated mask: none, random, topdown, and bottomup.
  • Figure 4: Fundus images from the datasets with the minimum and maximum entropy. The top row shows a diabetic and a normal fundus image, respectively, with the minimum entropy. The bottom row shows a diabetic and a normal fundus image, respectively, with the maximum entropy. The model is most confident in its predictions when the image is clear and least confident when the image is underexposed or overexposed.
  • Figure 5: Grad-CAM analysis of the Vision Transformer's attention maps. The top row shows fundus images of diabetic and normal patients with the minimum entropy. The bottom row shows fundus images of diabetic and normal patients with the maximum entropy. Each image is overlaid with its Grad-CAM attention map to show which regions the model considers important.
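The entropy ranking in Figures 4 and 5 can be illustrated with a short sketch. The figures do not state exactly how entropy is computed, so this assumes one common choice: the entropy (in nats) of the mean softmax distribution over T forward passes, each with a freshly sampled dropout mask. The helper names and toy logits are hypothetical. Averaging before taking the entropy means that disagreement between mask samples raises uncertainty, even when each individual pass is confident.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def predictive_entropy(logit_samples):
    """Entropy (nats) of the mean predictive distribution over T mask samples.

    logit_samples: T lists of class logits for one image, one per sampled mask.
    """
    probs = [softmax(l) for l in logit_samples]
    n_classes = len(probs[0])
    mean = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return -sum(p * math.log(p) for p in mean if p > 0)

def rank_by_entropy(per_image_samples):
    """Return (image_id, entropy) pairs sorted from most to least confident."""
    scores = {img: predictive_entropy(s) for img, s in per_image_samples.items()}
    return sorted(scores.items(), key=lambda kv: kv[1])

if __name__ == "__main__":
    # Toy logits for a binary diabetic-vs-normal classifier (hypothetical)
    samples = {
        "clear": [[4.0, -4.0], [3.5, -3.0], [4.2, -3.8]],
        "overexposed": [[1.0, -1.0], [-0.8, 0.9], [0.1, 0.0]],
    }
    for image_id, h in rank_by_entropy(samples):
        print(image_id, round(h, 3))
```

Under this assumption, the minimum- and maximum-entropy images in Figure 4 would correspond to the first and last entries of such a ranking.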