The dataset is available on Kaggle at https://www.kaggle.com/iarunava/cell-images-for-detecting-malaria
- Download the dataset
- Import the dataset into PyTorch using ImageFolder (resize and crop each image)
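```python
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
# data_dir: the extracted dataset folder, with one sub-folder per class
dataset = ImageFolder(data_dir, tt.Compose([tt.Resize(64),
                                            tt.RandomCrop(64),
                                            tt.ToTensor()]))
```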
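- Split the dataset into two groups: a training set and a validation set (`train_size + val_size` must equal `len(dataset)`)
```python
from torch.utils.data import random_split
train_ds, valid_ds = random_split(dataset, [train_size, val_size])
```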
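A sketch of one way to pick `train_size` and `val_size`, run on a stand-in `TensorDataset` of random tensors so it works on its own. The 10% validation fraction and the seed value 42 are assumptions; the original does not state either.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the ImageFolder dataset: 200 fake 3x64x64 images with 0/1 labels.
dataset = TensorDataset(torch.rand(200, 3, 64, 64), torch.randint(0, 2, (200,)))

# Assumption: hold out 10% of the images for validation.
val_size = int(0.1 * len(dataset))
train_size = len(dataset) - val_size  # remainder, so the sizes always sum to len(dataset)

# A seeded generator makes the split reproducible across runs.
train_ds, valid_ds = random_split(dataset, [train_size, val_size],
                                  generator=torch.Generator().manual_seed(42))
print(len(train_ds), len(valid_ds))   # 180 20
```

Computing the training size as the remainder avoids the rounding mismatch that can occur when both sizes are derived from fractions.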
- Prepare the data for training: wrap both sets in `DataLoader`s and use `make_grid` to preview a batch of images
- Move the data batches (and the model) to the GPU
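A common way to get the data onto the GPU is to move each batch as it is loaded rather than copying the whole dataset into GPU memory at once. The helper names below (`get_default_device`, `to_device`, `DeviceDataLoader`) are hypothetical; the sketch falls back to the CPU when no GPU is present.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def get_default_device():
    """Pick the GPU if one is available, otherwise fall back to the CPU."""
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def to_device(data, device):
    """Move a tensor or model, or a list/tuple of tensors, to the chosen device."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader:
    """Wrap a DataLoader so each batch is moved to `device` as it is yielded."""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)

# Usage with a stand-in loader of random images
device = get_default_device()
dl = DataLoader(TensorDataset(torch.rand(10, 3, 64, 64), torch.randint(0, 2, (10,))),
                batch_size=4)
device_dl = DeviceDataLoader(dl, device)
xb, yb = next(iter(device_dl))
print(xb.device, len(device_dl))
```

The same `to_device` call works for the model, since `nn.Module` also has a `.to()` method.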
- Define a neural network (ResNet9)
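A sketch of one commonly used ResNet9 layout (convolution, batch norm, and ReLU blocks with two residual stages). Two details are assumptions, not taken from the original: `num_classes=2` (one per dataset folder), and the final `nn.AdaptiveMaxPool2d(1)`, swapped in so the classifier works with 64x64 inputs regardless of the remaining spatial size.

```python
import torch
import torch.nn as nn

def conv_block(in_channels, out_channels, pool=False):
    # 3x3 convolution -> batch norm -> ReLU, optionally halving the spatial size
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)]
    if pool:
        layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

class ResNet9(nn.Module):
    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)    # 64x64 -> 32x32
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        self.conv3 = conv_block(128, 256, pool=True)   # 32x32 -> 16x16
        self.conv4 = conv_block(256, 512, pool=True)   # 16x16 -> 8x8
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        # Adaptive pooling collapses any remaining spatial size to 1x1,
        # so the linear layer always receives 512 features.
        self.classifier = nn.Sequential(nn.AdaptiveMaxPool2d(1),
                                        nn.Flatten(),
                                        nn.Dropout(0.2),
                                        nn.Linear(512, num_classes))

    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out   # residual (skip) connection
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out   # residual (skip) connection
        return self.classifier(out)

model = ResNet9(3, 2)
logits = model(torch.rand(4, 3, 64, 64))
print(logits.shape)  # torch.Size([4, 2])
```

The residual additions only type-check because each residual stage keeps the channel count and spatial size unchanged.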
- Train the model
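A minimal training-loop sketch. The function names `accuracy`, `evaluate`, and `fit`, the cross-entropy loss, and the default Adam optimizer are all assumptions. The loop also assumes batches already sit on the same device as the model. Validation metrics are averaged per batch, which is slightly approximate when the last batch is smaller.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def accuracy(outputs, labels):
    """Fraction of predictions in one batch that match the labels."""
    return (outputs.argmax(dim=1) == labels).float().mean().item()

@torch.no_grad()
def evaluate(model, val_loader):
    """Mean per-batch loss and accuracy on the validation set, without gradients."""
    model.eval()
    losses, accs = [], []
    for images, labels in val_loader:
        out = model(images)
        losses.append(F.cross_entropy(out, labels).item())
        accs.append(accuracy(out, labels))
    return {"val_loss": sum(losses) / len(losses), "val_acc": sum(accs) / len(accs)}

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.Adam):
    """Train for `epochs` passes and return one result dict per epoch."""
    optimizer = opt_func(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for images, labels in train_loader:
            loss = F.cross_entropy(model(images), labels)
            loss.backward()        # compute gradients
            optimizer.step()       # update weights
            optimizer.zero_grad()  # clear gradients before the next batch
            train_losses.append(loss.item())
        result = evaluate(model, val_loader)
        result["train_loss"] = sum(train_losses) / len(train_losses)
        print(f"Epoch {epoch + 1}: train_loss {result['train_loss']:.4f}, "
              f"val_loss {result['val_loss']:.4f}, val_acc {result['val_acc']:.4f}")
        history.append(result)
    return history

# Smoke test on a tiny stand-in model and random data (not the malaria images).
torch.manual_seed(0)
data = TensorDataset(torch.rand(64, 3, 8, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(data, batch_size=16, shuffle=True)
tiny_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))
history = fit(3, 1e-3, tiny_model, loader, loader)
```

Returning a list of per-epoch dicts keeps the losses available for plotting afterwards.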
- Plot losses against epochs
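A sketch of the loss plot using matplotlib. It assumes the training step returns a list of per-epoch dicts with `train_loss` and `val_loss` keys (a hypothetical format). The values below are dummy numbers used only to exercise the function, not results from this project.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt

def plot_losses(history):
    """Plot per-epoch training and validation loss from a list of result dicts."""
    train_losses = [h["train_loss"] for h in history]
    val_losses = [h["val_loss"] for h in history]
    epochs = range(1, len(history) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_losses, "-bx", label="Training")
    ax.plot(epochs, val_losses, "-rx", label="Validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Loss vs. number of epochs")
    ax.legend()
    return fig, ax

# Dummy values, only to exercise the function
history = [{"train_loss": 0.6, "val_loss": 0.5},
           {"train_loss": 0.3, "val_loss": 0.35},
           {"train_loss": 0.2, "val_loss": 0.3}]
fig, ax = plot_losses(history)
# fig.savefig("losses.png") writes the plot to a file
```

A validation curve that flattens or rises while the training curve keeps falling is the usual sign of overfitting.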
- Test the model and record the results