Welcome to Blurry2Clear, a deep learning project that aims to automate the classification of Diabetic Retinopathy (DR) stages from retinal fundus images using powerful computer vision techniques. This project is based on transfer learning with ResNet-50 and explores various training strategies to tackle data imbalance and optimize performance on the APTOS 2019 Blindness Detection dataset.
Diabetic Retinopathy is one of the leading causes of blindness worldwide among people with diabetes. Early detection is critical for timely treatment and vision preservation. This project seeks to build a robust and interpretable model that can:
- Classify retinal fundus images into five stages of Diabetic Retinopathy (0 to 4).
- Mitigate the challenges of data imbalance using advanced training techniques.
- Improve prediction performance through progressive model experimentation.
This repository contains four core Jupyter notebooks, each representing a different approach to model training and evaluation:
The first of these is the main notebook, which defines the end-to-end pipeline. It includes:
- Two versions:
- A baseline ResNet-50 classifier with default training.
- An improved pipeline with better preprocessing, early stopping, and tuned hyperparameters.
- Techniques used:
- Image normalization & augmentation (`RandomCrop`, `Flip`, `ColorJitter`).
- Transfer learning using a pretrained `ResNet-50` from `torchvision.models` (a minimal sketch follows below).
- Replacement of the final classification layer (`model.fc`) to adapt it to 5 classes.
- Early stopping and loss tracking for stable training.
- Output: Accuracy scores, loss curves, and a confusion matrix.
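As a rough illustration of this setup, the sketch below builds the augmentation pipeline and swaps out the classification head. The specific crop size, jitter strength, normalization statistics, and pretrained-weights flag are assumptions, not values taken from the notebook:

```python
import torch.nn as nn
from torchvision import models, transforms

# Training-time preprocessing: augmentation plus ImageNet-style normalization
# (crop size, jitter strength, and normalization stats are illustrative values)
train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Pretrained ResNet-50 with the final fully connected layer replaced
# so the classifier outputs 5 DR stages instead of 1000 ImageNet classes
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 5)
```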
The second notebook evaluates the impact of fine-tuning only the deeper layers of ResNet-50:
- Frozen layers: all layers except `layer4` and the final `fc` layer (see the freezing sketch after this list).
- Aimed to:
- Reduce overfitting.
- Speed up training on large datasets (over 26,000 images).
- Findings:
- Test Accuracy: ~74%
- Precision and recall are skewed due to class imbalance.
- Confusion matrix provided for per-class error analysis.
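A minimal sketch of this layer-freezing scheme is shown below; the optimizer choice and learning rate are assumptions rather than values confirmed in the notebook:

```python
import torch.nn as nn
import torch.optim as optim
from torchvision import models

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 5)

# Freeze the whole backbone, then re-enable gradients for layer4 and fc only
for param in model.parameters():
    param.requires_grad = False
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

# Optimize only the unfrozen parameters (learning rate is an assumed value)
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
```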
The third notebook addresses severe class imbalance using:
- Focal Loss:
- Helps the model focus on hard-to-classify minority classes.
- Gamma parameter set to 2.0 for down-weighting easy examples.
- WeightedRandomSampler:
- Balances the frequency of classes in each training batch.
- Limitations:
- Trained on a small subset (3,000 images) to test strategy feasibility.
- Accuracy was low but recall for rare classes improved significantly.
- Key takeaway: A promising direction given more tuning and more data (a sketch of the focal loss and sampler follows below).
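The sketch below shows one way the focal loss (with gamma = 2.0, as in the notebook) and the `WeightedRandomSampler` could be combined. The names `labels` and `train_dataset`, the batch size, and the inverse-frequency weighting are placeholder assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler

class FocalLoss(nn.Module):
    """Cross-entropy scaled by (1 - p_t)^gamma to down-weight easy examples."""
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)  # model's probability for the true class
        return ((1.0 - pt) ** self.gamma * ce).mean()

criterion = FocalLoss(gamma=2.0)

# Oversample rare classes: each image is drawn with probability inversely
# proportional to its class frequency ("labels" and "train_dataset" are
# assumed to come from the dataset-loading code).
labels_t = torch.tensor(labels)
class_counts = torch.bincount(labels_t)
sample_weights = (1.0 / class_counts.float())[labels_t]
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(labels),
                                replacement=True)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```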
The fourth approach produced the best overall results and serves as the final model candidate:
- Full training on all layers of ResNet-50.
- Combines:
- `CrossEntropyLoss` with class weights computed from the class distribution.
- `WeightedRandomSampler` for balanced batch composition.
- `EarlyStopping` for convergence monitoring (see the training sketch after this list).
- Dataset: Trained on ~21,400 preprocessed images.
- Final Results:
- Accuracy: ~78.41%
- Strong balance between precision, recall, and F1-score.
- Stable learning curves and effective generalization.
- Chosen as the preferred model for evaluation and inference.
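A sketch of how class-weighted cross-entropy and early stopping could fit together is given below (it would be used alongside the sampler from the previous sketch). The weighting formula, the patience value, the epoch count, and the `train_one_epoch` / `evaluate` helpers are illustrative assumptions, not the notebook's exact implementation:

```python
import numpy as np
import torch
import torch.nn as nn

# Class weights inversely proportional to class frequency
# ("labels" is assumed to hold the integer training labels, classes 0-4)
counts = np.bincount(labels, minlength=5)
class_weights = torch.tensor(len(labels) / (5 * counts), dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Early stopping on validation loss with a fixed patience
best_val_loss, patience, bad_epochs = float("inf"), 5, 0
for epoch in range(100):
    train_one_epoch(model, train_loader, criterion, optimizer)  # assumed helper
    val_loss = evaluate(model, val_loader, criterion)           # assumed helper
    if val_loss < best_val_loss:
        best_val_loss, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break
```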
We evaluate the performance using standard classification metrics:
- Accuracy: Overall proportion of correct predictions.
- Precision / Recall / F1-score: Macro-averaged across the 5 classes.
- Confusion Matrix: Visual analysis of class-wise prediction errors.
- Loss & Accuracy Curves: Monitored across epochs for both the training and validation sets (see the evaluation sketch after this list).
The project integrates several key AI and ML techniques:
| Component | Technique Used |
|---|---|
| Feature Extraction | Pretrained ResNet-50 |
| Transfer Learning | Frozen base layers / Fine-tuning variants |
| Image Preprocessing | Resize, Normalize, Data Augmentation |
| Loss Functions | CrossEntropyLoss, FocalLoss |
| Imbalance Handling | Weighted Losses, WeightedRandomSampler |
| Evaluation | Scikit-learn metrics, Matplotlib plots |
Install the required packages using pip:
pip install torch torchvision matplotlib scikit-learn numpy