-
Notifications
You must be signed in to change notification settings - Fork 13
/
predict.py
43 lines (39 loc) · 1.41 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from utils.helpers import *
import warnings
from PIL import Image
from torchvision import transforms
#from torchsummary import summary
def image_transform(imagepath):
    """Load an image from disk and turn it into a normalized tensor.

    The image is forced to 3-channel RGB, resized so that its shorter side
    is 255 px, center-cropped to 224x224, converted to a float tensor and
    normalized channel-wise.

    Args:
        imagepath: Anything ``PIL.Image.open`` accepts (a path or an open
            binary file object).

    Returns:
        A ``torch.Tensor`` of shape ``(3, 224, 224)``.

    Raises:
        FileNotFoundError: If ``imagepath`` does not exist.
        PIL.UnidentifiedImageError: If the file is not a readable image.
    """
    # NOTE(review): these mean/std values look like the usual ImageNet
    # statistics -- confirm they match what the model was trained with.
    test_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # The context manager closes the underlying file handle once the pixel
    # data has been consumed, instead of leaving it to garbage collection.
    with Image.open(imagepath) as image:
        # Grayscale, palette and RGBA files would otherwise yield 1 or 4
        # channels and make the 3-channel Normalize above fail.
        return test_transforms(image.convert('RGB'))
# Loaded models keyed by checkpoint path, so repeated predict() calls reuse
# the same instance instead of re-reading the checkpoint from disk.
_MODEL_CACHE = {}


def _get_model(model_path, verbose=False):
    """Return the eval-mode model for *model_path*, loading it on first use."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = load_model(model_path)
        model.eval()
        _MODEL_CACHE[model_path] = model
        if verbose:
            print("Model Loaded..")
    return model


def predict(imagepath, verbose=False):
    """Classify the image at *imagepath* as a cat or a dog.

    Args:
        imagepath: Path (or file object) of the image to classify; passed
            straight to ``image_transform``.
        verbose: When False (the default) warnings raised while loading the
            model and running inference are suppressed. When True they are
            shown and a message is printed when the model is loaded.

    Returns:
        A dict ``{'class': 'dog' | 'cat', 'confidence': <str>}`` where
        ``confidence`` is the top score formatted with ``str()``.
    """
    # catch_warnings restores the process-wide warning filters on exit, so
    # silencing warnings here does not leak into the rest of the program.
    with warnings.catch_warnings():
        if not verbose:
            warnings.simplefilter('ignore')
        model = _get_model('./models/catvdog.pth', verbose=verbose)
        # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
        batch = image_transform(imagepath).unsqueeze(0)
        # Pure inference: no need to record an autograd graph.
        with torch.no_grad():
            # NOTE(review): exp() assumes the model outputs log-probabilities
            # (e.g. a LogSoftmax head) -- confirm against load_model.
            ps = torch.exp(model(batch))
    topconf, topclass = ps.topk(1, dim=1)
    # Class index 1 is reported as 'dog'; any other index as 'cat'.
    label = 'dog' if topclass.item() == 1 else 'cat'
    return {'class': label, 'confidence': str(topconf.item())}
#print(predict('data/dog1.jpeg'))
#print(predict('data/cat1.jpeg'))
#print(predict('data/dog2.jpeg'))
#print(predict('data/cat2.jpeg'))