A multiclass problem
Option 2 for GA Capstone 22-Oct-2022
More than 3.4 million Americans aged 40 years and older are blind (having a visual acuity of 20/200 or less or a visual field of 20 degrees or less) or visually impaired (having a visual acuity of 20/40 or less). Other estimates of “vision problems” range as high as 21 million, and a total of 80 million Americans have potentially blinding eye diseases. The major causes of vision loss are cataracts, age-related macular degeneration, diabetic retinopathy, and glaucoma. (CDC et al)
In Singapore, diabetic retinopathy is the leading cause of vision loss in working-age adults. According to a study led by researchers at the Singapore Eye Research Institute (SERI) in 2015, diabetic retinopathy has to date claimed the sight of more than 600 Singaporeans, the loss of an eye in 8,000, and visual impairment in a further 17,500. The problem worsens as the risk of blindness increases fifteen-fold for Singaporeans aged 50 to 80 and above. As age steals away the senses, vision loss is perhaps the most devastating, as it increases the risk of falls, depression and even premature death. (SNEC et al)
Fictional scenario
The Singapore Health Promotion Board is searching for innovative, cutting-edge technological solutions to facilitate mass screening for common eye diseases in the general adult population. It hopes to automate the interpretation of retina screening images and shorten the time taken to flag high-risk persons for further health assessment. This would allow early intervention and lower both the risk of disease progression and the overall healthcare cost burden. Minimising the number of people in the population with severe eye disease also minimises the use of more costly therapies.
Using 4,217 retinal images collected from various sources such as IDRiD, Ocular Recognition and HRF, construct a machine learning model to classify an RGB retina image into one of 4 classes (cataract, diabetic retinopathy, glaucoma and normal).
Potential stakeholders: public health authorities, eye clinics, and neighbourhood eye screening centres (e.g. optical shops)

Eye disease | No. of images |
---|---|
Cataract | 1038 |
Diabetic retinopathy | 1098 |
Glaucoma | 1007 |
Normal | 1074 |
**Total** | **4217** |
Repo
├── model_notebooks
| ├── EfficientNet.ipynb
| └── InceptionResNetV2.ipynb
├── history
| ├── ENet_history.csv
| ├── ENet-a_history.csv
| ├── IRN_history.csv
| └── IRN-a_history.csv
├── working_models
| ├── ENet_ep20_val0.311.zip
| └── IRN_ep13_val0.347.zip
├── assets
| ├── EfficientNet.png
| └── InceptionResNetV2.png
├── README.md
└── slides.pdf
For simplicity, the following is based on the EfficientNet notebook, since the two notebooks are very similar apart from the ML model used.
- Import image dataset from Kaggle using Kaggle API
- Explore images, image hashing
- Train, test, validation datasets split
- Visualise image augmentation sample
- Use EfficientNetV2S for Transfer learning as Base model with custom metrics
- Evaluate loss and accuracy trends
- Add data augmentation layers to model and run
- Evaluate loss and accuracy trends
- Confusion matrices, Classification report
- Evaluate misclassified images
- Saliency Maps
- Activation Heatmaps
- ROC, PR curves to compare with InceptionResNetV2 model
- Recommendations
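The image hashing step in the outline above can be sketched as follows. This is a minimal illustration, not the notebook's code: it assumes duplicates are found by hashing the raw file bytes with SHA-256 (the notebook may use a different method, such as a perceptual hash), and `file_digest`, `find_duplicates` and the demo file names are hypothetical.

```python
import hashlib
import tempfile
from pathlib import Path


def file_digest(path: Path) -> str:
    """Return the SHA-256 hex digest of a file's raw bytes."""
    return hashlib.sha256(path.read_bytes()).hexdigest()


def find_duplicates(image_dir: Path) -> dict:
    """Group files under image_dir by digest; keep only groups with more than one file."""
    groups = {}
    for path in sorted(image_dir.rglob("*")):
        if path.is_file():
            groups.setdefault(file_digest(path), []).append(path)
    return {digest: paths for digest, paths in groups.items() if len(paths) > 1}


# Demo on throwaway files standing in for images (hypothetical names and contents).
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "cataract").mkdir()
    (root / "normal").mkdir()
    (root / "cataract" / "a.jpg").write_bytes(b"same-bytes")
    (root / "normal" / "b.jpg").write_bytes(b"same-bytes")  # byte-identical copy
    (root / "normal" / "c.jpg").write_bytes(b"other-bytes")

    duplicates = find_duplicates(root)
    total_files = sum(1 for p in root.rglob("*") if p.is_file())
    # Each duplicate group contributes one unique file; the rest are extra copies.
    unique_files = total_files - sum(len(paths) - 1 for paths in duplicates.values())
    print(f"{unique_files} unique files out of {total_files}")
```

Byte-level hashing only catches identical files. Copies that were resized or re-encoded would need a perceptual hash instead.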
- Download the dataset from Kaggle into a local drive using `!kaggle datasets download gunavenkatdoddi/eye-diseases-classification`
- Explore images
  - Image count for each class / total image count
  - Image attributes: `(512, 512, 3)`, `RGB`, pixels `range(0, 255)`
  - Image hashing: `4215` unique images
- Split dataset using the `splitfolders` library at a ratio of `(.5, .25, .25)`
  - Read folders into tensor dataset formats using `image_dataset_from_directory()`, set `shuffle=False` for prediction evaluation
- Visualise preprocessed images
  - Resizing to `(160,160)` for EfficientNet or `(299,299)` for InceptionResNetV2
  - Rescaling pixels to `(0, 1)` - allows faster model convergence
  - Data augmentation using `RandomFlip("horizontal")`, `RandomRotation(0.1)`, `RandomContrast(0.1)`
- Create base model using transfer learning without data augmentation layers
  - EfficientNetV2S: `input_shape=(160,160,3)`, `include_top=False`, `weights='imagenet'`, `pooling='max'`, `base_model.trainable=True`
  - Compile model: `optimizer='Adamax'`, `loss='categorical_crossentropy'`
    - Metrics: `categorical_accuracy`, `precision`, `recall`, `auc`, custom function `f1_score`
- Evaluate base model
  - Accuracy 93.4%
  - Loss 0.311
  - Accuracy generalisation 6.29%
- Add data augmentation layers to model
- Evaluate augmented model - took more epochs to reach minimum loss
  - Accuracy 94.2%
  - Loss 0.312
  - Accuracy generalisation 5.66%
- Confusion matrix
- Classification report
- Some misclassified images. Reasons for misclassification:
  - Poorly captured image
  - Incorrect label in the source dataset
  - Abnormality not obvious or too mild (e.g. mild cataract)
- Saliency map - Glaucoma example
- Activation heatmap - Glaucoma example. The optic disc region activates the model the most, which is the desired behaviour for detecting glaucoma.
- Comparison between EfficientNetV2S and InceptionResNetV2 models. InceptionResNetV2 is marginally superior to EfficientNetV2S, but inferior at identifying images with glaucoma.

Class | Test set metric | EfficientNetV2S | InceptionResNetV2 |
---|---|---|---|
- | Accuracy | 93.00% | 92.43% |
Cataract | Precision | 94.07% | 90.91% |
Cataract | Recall | 91.54% | 96.15% |
Cataract | F1 score | 92.79% | 93.46% |
DRE | Precision | 99.64% | 99.28% |
DRE | Recall | 100% | 100% |
DRE | F1 score | 99.82% | 99.64% |
Glaucoma | Precision | 90.04% | 86.96% |
Glaucoma | Recall | 89.33% | 86.96% |
Glaucoma | F1 score | 89.68% | 86.96% |
Normal | Precision | 88.09% | 92.06% |
Normal | Recall | 90.71% | 86.25% |
Normal | F1 score | 89.38% | 89.06% |

DRE: diabetic retinopathy
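
As a consistency check on the comparison table, the F1 score is the harmonic mean of precision and recall, F1 = 2PR / (P + R). The sketch below recomputes it from the precision and recall reported in the EfficientNetV2S column. `f1_from_pr` is an illustrative helper written for this README, not the custom `f1_score` metric used in the notebooks.

```python
def f1_from_pr(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (result is in the same units as the inputs)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision %, recall %, reported F1 %) copied from the EfficientNetV2S column.
efficientnet_rows = {
    "Cataract": (94.07, 91.54, 92.79),
    "DRE": (99.64, 100.0, 99.82),
    "Glaucoma": (90.04, 89.33, 89.68),
    "Normal": (88.09, 90.71, 89.38),
}

for name, (precision, recall, reported) in efficientnet_rows.items():
    computed = f1_from_pr(precision, recall)
    print(f"{name}: computed F1 {computed:.2f}%, reported F1 {reported:.2f}%")
```

The recomputed values agree with the reported F1 scores to two decimal places.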
Limitations:
- The model cannot identify retina images with multiple abnormalities, e.g. a retina with both diabetic retinopathy and cataract. It outputs only one class per image, not multiple classes.
- The model is unable to pinpoint the region or abnormal feature found in the retina image; it can only classify the image as a whole.
- Retina images with poor resolution or poor image capture (no retina or optic disc visualised) have a higher chance of being wrongly classified.
Conclusions:
- Because most wrongly classified images were predicted with high probability (>0.9), raising the probability threshold is unlikely to improve the evaluation metrics.
- This model is best used alongside other screening modalities to increase the precision and accuracy of diagnosis, e.g.
  - Tonometer (intraocular pressure) for glaucoma
  - Snellen chart (visual acuity) for overall vision ability
  - Automated perimeter visual field analyser for glaucoma
  - Optical coherence tomography (OCT) to identify retina layers and subretinal deposits
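
The first conclusion can be illustrated with a small sketch. The probabilities and labels below are made-up toy values, not outputs from either model, and `accuracy_at_threshold` is a hypothetical helper. It withholds any prediction whose top softmax probability falls below a threshold, then reports how many images are still classified (coverage) and how accurate those remaining predictions are.

```python
import numpy as np


def accuracy_at_threshold(probs, y_true, threshold):
    """Return (coverage, accuracy) when predictions with confidence below threshold are withheld."""
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true)
    confidence = probs.max(axis=1)   # top softmax probability per image
    y_pred = probs.argmax(axis=1)    # predicted class index per image
    kept = confidence >= threshold
    coverage = float(kept.mean())
    if not kept.any():
        return coverage, float("nan")
    accuracy = float((y_pred[kept] == y_true[kept]).mean())
    return coverage, accuracy


# Toy softmax outputs over 4 classes (illustrative values only).
toy_probs = [
    [0.97, 0.01, 0.01, 0.01],  # correct, confident
    [0.02, 0.95, 0.02, 0.01],  # correct, confident
    [0.01, 0.02, 0.94, 0.03],  # wrong (true class is 3) but confident
    [0.55, 0.05, 0.10, 0.30],  # correct, low confidence
    [0.03, 0.01, 0.92, 0.04],  # correct, confident
]
toy_true = [0, 1, 3, 0, 2]

for threshold in (0.0, 0.9):
    coverage, accuracy = accuracy_at_threshold(toy_probs, toy_true, threshold)
    print(f"threshold={threshold}: coverage={coverage:.2f}, accuracy={accuracy:.2f}")
```

In this toy case the confident error survives the 0.9 threshold while a correct low-confidence prediction is dropped, so accuracy falls from 0.80 to 0.75 and coverage falls from 1.00 to 0.80.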
Suggestions:
- Try individual colour channels for input images, e.g. using only the green channel as input. This may help the model train better.
- Remove the black background from input images by adding an alpha channel for transparency. This may help the model learn diabetic retinopathy more accurately.
- Train a different model on just the diabetic retinopathy images vs normal images (keeping the black background), and see if the model is activated by the correct features in the images with diabetic retinopathy rather than by the background.
- Improve the input images by removing images with non-obvious abnormalities, distortion, or features not distinguishable for the particular class, e.g.
  - Glaucoma: need to have a distinct optic disc
  - Cataract: need to have distinct blurring or dullness of the retina
  - Normal: need to have good quality images, with no confusing peripheral dullness, blurring or obscured optic disc that can easily be confused with glaucoma or cataract
- Try Meta Pseudo Labels, a semi-supervised learning method that reported a state-of-the-art top-1 accuracy of 90.2% on ImageNet at the time of publication, and see if the model learns better with less confusion and less background activation.
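
The green-channel suggestion can be sketched with NumPy as below. This is an illustration under stated assumptions rather than code from the notebooks: `green_channel` is a hypothetical helper, the input is assumed to be an `(height, width, 3)` RGB array, and the option to repeat the channel three times is only there because ImageNet-pretrained backbones usually expect 3-channel input.

```python
import numpy as np


def green_channel(image, keep_three_channels=False):
    """Extract the green channel from an (H, W, 3) RGB array.

    Returns shape (H, W, 1), or (H, W, 3) with the green channel repeated
    when keep_three_channels is True.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError("expected an (height, width, 3) RGB array")
    green = image[..., 1:2]  # slicing keeps the channel axis
    if keep_three_channels:
        green = np.repeat(green, 3, axis=-1)
    return green


# Synthetic 512x512 RGB image standing in for a retina photo.
rgb = np.zeros((512, 512, 3), dtype=np.uint8)
rgb[..., 1] = 200  # fill only the green channel

print(green_channel(rgb).shape)                            # (512, 512, 1)
print(green_channel(rgb, keep_three_channels=True).shape)  # (512, 512, 3)
```

Feeding a single-channel image to the model would also require changing `input_shape`, which in turn affects whether the pretrained ImageNet weights can be loaded.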