-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model.py
75 lines (54 loc) · 2 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import torch
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from utils.model import UNet
from utils import config
def log(message, dots=True):
    """Print *message* as an ``[INFO]``-prefixed status line.

    Args:
        message: Value to report; it is formatted with an f-string.
        dots: When truthy (the default), append ``'...'`` to mark work in
            progress.
    """
    suffix = '...' if dots else ''
    print(f'[INFO] {message}{suffix}')
def prepare_plot(img, original_mask, predicted_mask):
    """Display an image beside its ground-truth and predicted masks.

    Builds one row of three panels (image, original mask, predicted mask),
    each drawn with ``cmap='gray'`` and titled, tightens the layout, and
    hands the figure to ``plt.show()``.
    """
    panels = (
        (img, 'Image'),
        (original_mask, 'Original Mask'),
        (predicted_mask, 'Predicted Mask'),
    )
    fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(10, 10))
    for axis, (data, title) in zip(axes, panels):
        axis.imshow(data, cmap='gray')
        axis.set_title(title)
    fig.tight_layout()
    plt.show()
def predict(model: UNet, img_path: str):
    """Run *model* on one image and plot it beside its ground-truth mask.

    The image at *img_path* is read with OpenCV (BGR), converted to RGB,
    resized to 128x128, scaled to [0, 1] and fed to *model* as a
    1 x C x H x W float32 tensor on ``config.DEVICE``. The ground-truth
    mask is looked up by file name under ``config.MASK_DATASET_PATH``,
    read as grayscale and resized to a ``config.INPUT_IMAGE_HEIGHT``
    square. The model output is passed through a sigmoid and binarised at
    ``config.THRESHOLD`` into a 0/255 ``uint8`` mask, then all three are
    shown with ``prepare_plot``.

    Args:
        model: Segmentation network. Callers should put it in eval mode.
        img_path: Path to the input image.

    Raises:
        FileNotFoundError: If the image or its ground-truth mask cannot be
            read (``cv.imread`` returns ``None`` rather than raising).
    """
    img = cv.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    # NOTE(review): the image uses a hard-coded 128x128 while the mask below
    # uses config.INPUT_IMAGE_HEIGHT; confirm these agree with training.
    img = cv.resize(img, (128, 128))
    original_img = img.copy()
    img = img.astype(np.float32) / 255.

    # basename also copes with '/' separators on Windows, unlike splitting
    # on os.path.sep alone.
    file_name = os.path.basename(img_path)
    ground_truth_path = os.path.join(config.MASK_DATASET_PATH, file_name)
    original_mask = cv.imread(ground_truth_path, 0)  # 0 == grayscale
    if original_mask is None:
        raise FileNotFoundError(
            f'Could not read ground-truth mask: {ground_truth_path}'
        )
    original_mask = cv.resize(
        original_mask,
        (config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_HEIGHT),
    )

    # HWC -> CHW, then add a leading batch dimension of 1.
    img = np.transpose(img, (2, 0, 1))
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).to(config.DEVICE)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        probabilities = torch.sigmoid(model(img).squeeze())
    probabilities = probabilities.cpu().numpy()

    # Filter out the weak predictions and convert them to 0/255 integers.
    predicted_mask = ((probabilities > config.THRESHOLD) * 255).astype(np.uint8)
    prepare_plot(original_img, original_mask, predicted_mask)
# Guarded so that importing this module (e.g. pytest collecting
# test_model.py) does not load the model and open plot windows.
if __name__ == '__main__':
    log('Loading up test image paths')
    with open(config.TEST_PATHS) as file:
        # splitlines() handles CRLF files; blank lines (e.g. an empty file or
        # a trailing newline) are skipped instead of becoming '' paths.
        img_paths = [p for p in (line.strip() for line in file.read().splitlines()) if p]
    if not img_paths:
        raise SystemExit(f'[INFO] No test image paths found in {config.TEST_PATHS}')
    # Sample without replacement so no image is shown twice; cap the sample
    # size so fewer than 10 available paths does not raise.
    img_paths = np.random.choice(
        img_paths, size=min(10, len(img_paths)), replace=False
    )

    log('Loading up the UNet model')
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): this unpickles a whole model object, which can execute
    # arbitrary code, so only load trusted files. Newer PyTorch releases may
    # also require weights_only=False for this kind of checkpoint.
    unet: UNet = torch.load(config.MODEL_PATH, map_location=config.DEVICE).to(config.DEVICE)
    # Put layers such as dropout / batch-norm into inference behaviour.
    unet.eval()
    for path in img_paths:
        predict(unet, path)