-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
86 lines (73 loc) · 2.75 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from albumentations.pytorch import ToTensorV2
from matplotlib.backends.backend_pdf import PdfPages
from tqdm import tqdm
from config import constants as C
from models.model_zoo import BoneAgeEstModelZoo
# Prefer the first visible CUDA GPU; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using {} device".format(device))
def load_model(fold_path):
    """Restore a trained ``BoneAgeEstModelZoo`` checkpoint in evaluation mode.

    Args:
        fold_path: Path to the checkpoint file to load.

    Returns:
        The restored model with it and its ``model``, ``classifier`` and
        ``gender`` submodules switched to eval mode, or ``None`` when
        ``fold_path`` is not an existing file (a message is printed then).
    """
    pretrained_filename = fold_path
    if not os.path.isfile(pretrained_filename):
        print("No pretrained model found for testing")
        return None

    # ``load_from_checkpoint`` is a classmethod, so call it on the class.
    # Calling it on a freshly built instance constructed a throwaway model
    # (with pretrained=True) for nothing, and PyTorch Lightning >= 2.0 raises
    # a TypeError when it is invoked on an instance.
    model = BoneAgeEstModelZoo.load_from_checkpoint(pretrained_filename)

    # Put the submodules into eval mode explicitly (disables dropout /
    # batch-norm updates), then the wrapper itself.
    model.model.eval()
    model.classifier.eval()
    model.gender.eval()
    model.eval()
    return model
def _build_transform(image_size):
    """Return the evaluation preprocessing pipeline resizing to ``image_size`` x ``image_size``."""
    return A.Compose([
        A.Resize(width=image_size, height=image_size),
        A.CLAHE(),
        A.Normalize(),
        ToTensorV2(),
    ])


def _plot_prediction(pdf, image, actual, predicted, path):
    """Append one PDF page showing ``image`` titled with the actual and predicted age."""
    basename = os.path.splitext(os.path.basename(path))[0]
    fig, ax = plt.subplots()
    # cv2.imread returns BGR channel order; matplotlib expects RGB.
    ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    ax.set_title(
        f"Actual Age: {actual} months\nPredicted Age: {int(predicted)} months")
    ax.text(0.5, -0.1, f"Image: {basename}", transform=ax.transAxes, ha='center')
    ax.axis('off')
    pdf.savefig(fig, bbox_inches='tight')
    plt.close(fig)


def test_model(tc):
    """Evaluate the model on every row of a CSV and write a per-image PDF report.

    Each row's image (``path`` column) is preprocessed and fed to the model
    together with its ``boneage`` and ``gender`` values. Every image is drawn
    on its own PDF page titled with the actual and predicted age, and the mean
    absolute error over all rows is printed at the end.

    Args:
        tc: Configuration mapping. Keys read: ``image_size``, ``valid_df``
            (path to the CSV), ``pretrained_filename`` (checkpoint passed to
            ``load_model``) and ``pdf_filename`` (output report).

    Returns:
        The mean absolute error (months, per the plot labels), or ``None`` if
        no model could be loaded or the CSV contained no rows.

    Raises:
        FileNotFoundError: if an image listed in the CSV cannot be read.
    """
    transform = _build_transform(tc['image_size'])
    valid_df = pd.read_csv(tc['valid_df'])
    model = load_model(tc['pretrained_filename'])
    if model is None:
        # load_model has already printed why; there is nothing to evaluate.
        return None
    # load_model does not pick a device for the weights, while every input
    # below is sent to ``device``; move the model so the two always match.
    model = model.to(device)

    abs_errors = []
    # Inference only: torch.no_grad() skips building an autograd graph per image.
    with PdfPages(tc['pdf_filename']) as pdf, torch.no_grad():
        for _, row in tqdm(valid_df.iterrows(), total=len(valid_df)):
            path = row['path']
            image = cv2.imread(path)
            if image is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image: {path}")
            processed_image = transform(image=image)['image'].unsqueeze(0).to(device)
            boneage = torch.tensor(row['boneage']).unsqueeze(0).unsqueeze(1).to(device)
            gender = torch.tensor(row['gender']).unsqueeze(0).unsqueeze(1).to(device)
            scans = {
                'image': processed_image,
                'boneage': boneage,
                'gender': gender,
            }
            val_result = model(scans)

            actual = boneage.item()
            predicted = val_result.item()
            abs_errors.append(abs(actual - predicted))
            _plot_prediction(pdf, image, actual, predicted, path)

    if not abs_errors:
        # np.mean([]) would print a RuntimeWarning and report nan.
        print("Mean error: no samples evaluated")
        return None
    mae = float(np.mean(abs_errors))
    print("Mean error: ", mae)
    return mae
if __name__ == "__main__":
    # Load the YAML configuration that drives evaluation.
    with open(C.MODEL_TEST_CONFIG, 'r', encoding='utf-8') as stream:
        try:
            test_config = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Without a parsed config there is nothing to run; stop here rather
            # than continuing into a NameError on an undefined variable.
            raise SystemExit(
                f"Could not parse test config {C.MODEL_TEST_CONFIG}: {exc}") from exc
    # yaml.safe_load returns None for an empty file and may return a scalar or
    # list; test_model indexes the config by key, so require a mapping.
    if not isinstance(test_config, dict):
        raise SystemExit(
            f"Test config {C.MODEL_TEST_CONFIG} must be a YAML mapping")
    test_model(test_config)