This project implements four deep learning models that classify diabetic retinopathy severity from retinal images: VGG16, VGG19, Xception, and InceptionV3. All four are pre-trained on ImageNet and fine-tuned for this task.
- Project Overview
- Prerequisites
- Dataset
- Model Architecture
- Training Process
- Performance Evaluation
- Usage
- Results
- Future Work
Diabetic retinopathy is a diabetes complication that affects the eyes. Early detection is crucial for preventing vision loss. This project aims to automate the classification of diabetic retinopathy severity using machine learning techniques on retinal images.
The project requires the following libraries:
- TensorFlow
- NumPy
- Matplotlib
- Scikit-learn
- Seaborn
- Pillow
You can install these dependencies using pip:
pip install tensorflow numpy matplotlib scikit-learn seaborn pillow
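After installing, a quick check like the one below can confirm that every dependency is importable. It is only a convenience sketch and not part of the project's scripts; the helper name `missing_packages` is made up for illustration. Note that two pip names differ from their import names: scikit-learn imports as `sklearn` and Pillow imports as `PIL`.

```python
from importlib.util import find_spec

# pip package name -> module name used in `import` statements.
REQUIRED = {
    "tensorflow": "tensorflow",
    "numpy": "numpy",
    "matplotlib": "matplotlib",
    "scikit-learn": "sklearn",
    "seaborn": "seaborn",
    "pillow": "PIL",
}


def missing_packages(required=REQUIRED):
    """Return the pip names of packages whose module cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("All dependencies found." if not missing
          else "Missing: " + ", ".join(missing))
```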
The dataset should be organized in the following structure:
Diabetic Retinopathy ML Dataset/
├── train/
│ ├── class_0/
│ ├── class_1/
│ ├── class_2/
│ ├── class_3/
│ └── class_4/
└── test/
    ├── class_0/
    ├── class_1/
    ├── class_2/
    ├── class_3/
    └── class_4/
Each of the five class folders (class_0 through class_4) holds the images for one severity level of diabetic retinopathy.
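Before training, it can help to verify that the folder layout matches the structure above. The following is a minimal sketch using only the standard library. The function name `count_images` and the set of accepted file extensions are assumptions for illustration, not something the project defines.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}  # assumed image formats
CLASS_NAMES = [f"class_{i}" for i in range(5)]


def count_images(dataset_root):
    """Count image files per split and class; raise if a folder is missing."""
    root = Path(dataset_root)
    counts = {}
    for split in ("train", "test"):
        for class_name in CLASS_NAMES:
            class_dir = root / split / class_name
            if not class_dir.is_dir():
                raise FileNotFoundError(f"Missing folder: {class_dir}")
            counts[(split, class_name)] = sum(
                1 for p in class_dir.iterdir()
                if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts

# Example call, assuming the dataset sits next to the script:
# for (split, name), n in count_images("Diabetic Retinopathy ML Dataset").items():
#     print(f"{split}/{name}: {n} images")
```

Printing the per-class counts also exposes class imbalance early, which matters when accuracy is the main metric.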
The project implements four different models:
- VGG16
- VGG19
- Xception
- InceptionV3
Each model is pre-trained on ImageNet and fine-tuned for diabetic retinopathy classification. The models are modified by:
- Removing the top layers
- Adding custom fully connected layers
- Setting specific layers to be trainable or non-trainable
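The three modifications above can be sketched with the Keras applications API. This is an illustration under stated assumptions rather than the project's exact architecture: the pooling layer, the head size (`head_units`), the number of unfrozen layers (`trainable_tail`), and the reading of 176x208 as height by width are all placeholders chosen for the example.

```python
import tensorflow as tf

IMG_HEIGHT, IMG_WIDTH = 176, 208  # assumed to be (height, width)
NUM_CLASSES = 5                   # class_0 through class_4

BACKBONES = {
    "VGG16": tf.keras.applications.VGG16,
    "VGG19": tf.keras.applications.VGG19,
    "Xception": tf.keras.applications.Xception,
    "InceptionV3": tf.keras.applications.InceptionV3,
}


def build_model(name, weights="imagenet", trainable_tail=4, head_units=256):
    """Build a transfer-learning classifier on top of a named backbone.

    trainable_tail and head_units are illustrative values, not the project's.
    """
    base = BACKBONES[name](
        include_top=False,  # drop the original ImageNet classifier
        weights=weights,
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    )

    # Freeze every backbone layer except the last `trainable_tail` ones.
    n_frozen = len(base.layers) - trainable_tail
    for i, layer in enumerate(base.layers):
        layer.trainable = i >= n_frozen

    # Custom fully connected head ending in a 5-way softmax.
    x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
    x = tf.keras.layers.Dense(head_units, activation="relu")(x)
    outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
    return tf.keras.Model(inputs=base.input, outputs=outputs)
```

Calling `build_model("Xception")` downloads the ImageNet weights on first use; passing `weights=None` builds the same graph with random initialization, which is handy for quick shape checks.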
The training process includes:
- Data preprocessing and augmentation
- Splitting data into train, validation, and test sets
- Model compilation with SGD optimizer and categorical crossentropy loss
- Training for a specified number of epochs with early stopping
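Because the dataset ships with only train and test folders, the validation set presumably has to be carved out of the training images. One dependency-free way to do that per class is sketched below. The function name, the 20% validation fraction, and the seed are assumptions for illustration.

```python
import random
from collections import defaultdict


def stratified_split(samples, val_fraction=0.2, seed=42):
    """Split (path, label) pairs into train and validation lists per class.

    val_fraction and seed are illustrative defaults, not the project's values.
    """
    by_label = defaultdict(list)
    for path, label in samples:
        by_label[label].append(path)

    rng = random.Random(seed)
    train, val = [], []
    for label in sorted(by_label):
        paths = sorted(by_label[label])  # sort first so the split is reproducible
        rng.shuffle(paths)
        n_val = int(round(len(paths) * val_fraction))
        val += [(p, label) for p in paths[:n_val]]
        train += [(p, label) for p in paths[n_val:]]
    return train, val
```

Splitting within each class keeps the severity distribution similar in both subsets. With scikit-learn installed, `sklearn.model_selection.train_test_split` with its `stratify` argument achieves the same effect.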
Key hyperparameters:
- Image dimensions: 176 x 208 pixels
- Batch size: 16-64
- Learning rate: 0.0001
- Momentum: 0.9
- Epochs: 30 (adjustable)
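A compile-and-train setup using these hyperparameters might look like the sketch below. The learning rate, momentum, and epoch count come from the list above; the early-stopping `patience`, the monitored quantity, and the helper names are assumptions, since the source does not specify them.

```python
import tensorflow as tf

LEARNING_RATE = 0.0001
MOMENTUM = 0.9
EPOCHS = 30
BATCH_SIZE = 32  # any value in the 16-64 range listed above


def compile_for_training(model):
    """Compile with SGD plus momentum and categorical crossentropy."""
    model.compile(
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=LEARNING_RATE, momentum=MOMENTUM
        ),
        loss="categorical_crossentropy",  # expects one-hot encoded labels
        metrics=["accuracy"],
    )
    return model


def make_early_stopping(patience=5):
    """Stop when validation loss stops improving; patience is an assumed value."""
    return tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=patience,
        restore_best_weights=True,
    )

# Usage, assuming `model`, `train_ds` and `val_ds` already exist:
# compile_for_training(model)
# history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS,
#                     callbacks=[make_early_stopping()])
```

With early stopping in place, the epoch count acts as an upper bound rather than a fixed training length.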
The models are evaluated on three sets:
- Training set
- Validation set
- Test set
Metrics used:
- Accuracy
- Loss
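For reference, both metrics can be written out in a few lines of NumPy. This sketch follows the standard definitions for one-hot labels and predicted class probabilities; the function names are illustrative, and the clipping constant is an assumption made to avoid taking the logarithm of zero.

```python
import numpy as np


def accuracy(y_true, y_prob):
    """Fraction of samples whose highest-probability class matches the label."""
    return float(np.mean(np.argmax(y_prob, axis=1) == np.argmax(y_true, axis=1)))


def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_prob))."""
    y_prob = np.clip(y_prob, eps, 1.0)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_prob), axis=1)))
```

A model that assigns equal probability to all five classes has a loss of ln(5), roughly 1.61, which is a useful baseline when reading training logs.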
To train and evaluate a model:
- Prepare your dataset in the required directory structure.
- Adjust the hyperparameters in the script if needed.
- Run the script for the desired model (VGG16, VGG19, Xception, or InceptionV3).
- After training finishes, the script saves the trained model in HDF5 format.
The performance of each model is printed after training, showing the accuracy on the train, validation, and test sets.
Potential improvements and extensions:
- Implement k-fold cross-validation
- Experiment with other architectures (e.g., ResNet, DenseNet)
- Implement ensemble methods
- Analyze model interpretability using techniques like Grad-CAM
- Deploy the best performing model as a web service
Feel free to contribute to this project by submitting pull requests or opening issues for bugs and feature requests.