Enhance Eye Disease Detection using Learnable Probabilistic Discrete Latents in Machine Learning Architectures
Anirudh Prabhakaran, YeKun Xiao, Ching-Yu Cheng, Dianbo Liu
TL;DR
This work tackles reliable ocular disease detection from fundus images by introducing learnable probabilistic discrete latents through GFlowOut, integrated into ResNet18 and Vision Transformer backbones. By modeling the posterior over dropout masks and training with a Trajectory Balance objective, the approach improves accuracy, uncertainty estimation, and calibration under distribution shifts and noise. The method yields strong empirical gains on multiple fundus datasets, improves out-of-distribution (OOD) robustness, and produces interpretable Grad-CAM visualizations that highlight clinically relevant regions. The findings suggest significant potential for more reliable AI-assisted ophthalmic diagnostics and broader applicability to other medical imaging domains.
Abstract
Ocular diseases, including diabetic retinopathy and glaucoma, present a significant public health challenge due to their high prevalence and potential to cause vision impairment. Early and accurate diagnosis is crucial for effective treatment and management. In recent years, deep learning models have emerged as powerful tools for analysing medical images such as retinal images. However, challenges persist in model reliability and uncertainty estimation, both of which are critical for clinical decision-making. This study leverages the probabilistic framework of Generative Flow Networks (GFlowNets) to learn the posterior distribution over discrete latent dropout masks for the classification and analysis of ocular diseases from fundus images. We develop a robust and generalizable method that integrates GFlowOut with ResNet18 and Vision Transformer (ViT) backbones to identify various ocular conditions. The study employs four dropout-mask strategies (none, random, bottomup, and topdown) to enhance model performance in analysing these fundus images. Our results demonstrate that the learnable probabilistic latents significantly improve accuracy, outperforming the traditional dropout approach. We use Grad-CAM, a gradient-based class activation mapping method, to assess model explainability and observe that the model accurately focuses on the image regions critical to its predictions. The integration of GFlowOut into neural networks is a promising advance in the automated diagnosis of ocular diseases, with implications for improving clinical workflows and patient outcomes.
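To make the idea of a learnable posterior over dropout masks concrete, the following is a minimal, dependency-free sketch of how a binary mask can be sampled unit by unit from learned keep probabilities and scored with the Trajectory Balance objective mentioned in the TL;DR. It is an illustration under stated assumptions, not the authors' implementation. The function names `sample_mask` and `trajectory_balance_loss` are hypothetical. The keep probabilities are plain numbers here, whereas a learned scheme would produce them with a small network. The reward is assumed to be the label likelihood times a prior over masks, R(z) = p(y | x, z) p(z). The masks are assumed to be generated in a fixed order, so each partial mask has a single parent and the backward policy contributes log P_B = 0.

```python
import math
import random


def sample_mask(keep_probs, rng):
    """Sample a binary dropout mask unit by unit and return it with its log-probability.

    keep_probs: per-unit keep probabilities, each strictly between 0 and 1.
    In a learned scheme these would come from a network (assumption); with
    standard random dropout they would all equal the same fixed constant.
    """
    mask, log_pf = [], 0.0
    for p in keep_probs:
        keep = rng.random() < p
        mask.append(1 if keep else 0)
        # Accumulate the forward-policy log-probability of this decision.
        log_pf += math.log(p) if keep else math.log(1.0 - p)
    return mask, log_pf


def trajectory_balance_loss(log_z, log_pf, log_reward, log_pb=0.0):
    """Squared Trajectory Balance residual for one sampled trajectory:
    (log Z + log P_F(tau) - log R(z) - log P_B(tau | z)) ** 2.
    log_pb defaults to 0 under the fixed-generation-order assumption.
    """
    return (log_z + log_pf - log_reward - log_pb) ** 2


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical keep probabilities for two small layers.
    layers = [[0.9, 0.8, 0.7], [0.6, 0.5]]
    masks, log_pf_total = [], 0.0
    for probs in layers:
        m, lp = sample_mask(probs, rng)
        masks.append(m)
        log_pf_total += lp
    # Hypothetical reward terms: label log-likelihood plus a uniform log-prior over masks.
    log_likelihood = -0.35
    log_prior = sum(math.log(0.5) for probs in layers for _ in probs)
    log_reward = log_likelihood + log_prior
    log_z = 0.0  # a learnable scalar in practice
    loss = trajectory_balance_loss(log_z, log_pf_total, log_reward)
    print(masks, round(loss, 4))
```

In training, the keep-probability network and log Z would be updated to drive this residual toward zero, so that masks are sampled in proportion to their reward. Standard dropout, by contrast, has no such objective and keeps every unit with the same fixed probability.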
