-
Notifications
You must be signed in to change notification settings - Fork 3
/
test.py
60 lines (51 loc) · 2.12 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
import sys
import caffe
import cv2
from PIL import Image
import matplotlib
import numpy as np
import lmdb
# Classify one image with a trained LeNet-style Caffe model and print the
# most probable Sinhala characters.
#
# NOTE: the backend below is never used in this script (nothing is plotted);
# the assignment is kept so running/importing the script has the same side
# effects as before.
matplotlib.rcParams['backend'] = "Qt4Agg"
caffe_root = '/home/wathmal/Documents/caffe'
MODEL_FILE = './lenet.prototxt'
PRETRAINED = './model/lenet_iter_600.caffemodel'
# Image classified when no path is passed as the first command-line argument.
DEFAULT_IMAGE = '/home/sasithas/IdeaProjects/sinhala-ocr/imgs/unicode/bhashita_complex_5.png'
# Multiplier applied to the [0, 1] float image produced by caffe.io.load_image.
# 1.0 suits a net trained with `scale: 0.00390625` in its data layer; use
# 255.0 if the net was trained on raw 0-255 pixels (which is what the
# commented-out LMDB evaluation at the end of this file feeds it).
# NOTE(review): confirm against the training prototxt.
INPUT_SCALE = 1.0
# How many of the best-scoring classes to report.
TOP_K = 5
# Class index i of the net's 'prob' output is reported as sin_labels[i];
# presumably this order matches the labels used to build the training data.
sin_labels = ['ක','ඛ', 'ග', 'ඝ', 'ඟ', 'ච', 'ඡ', 'ජ', 'ට', 'ඩ', 'න', 'ණ', 'ත', 'ථ', 'ද', 'ධ', 'ප', 'ඵ', 'බ', 'භ', 'ම', 'ඹ', 'ය', 'ර', 'ල', 'ව', 'ශ', 'ෂ', 'ස','හ','ළ', 'ෆ']

# Pick the compute mode before building the net (the usual pycaffe order).
caffe.set_mode_cpu()
net = caffe.Net(MODEL_FILE, PRETRAINED, caffe.TEST)

# Test self-made image (path optionally overridden on the command line).
image_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_IMAGE
# With color=False, load_image yields float32 values in [0, 1], shape (H, W, 1).
# Do NOT cast to uint8 here: that truncates every pixel below 1.0 to 0 and
# erases all grey levels. Scale the floats instead.
img = caffe.io.load_image(image_path, color=False) * INPUT_SCALE
# Add a batch axis and move channels first: (H, W, C) -> (1, C, H, W).
out = net.forward_all(data=np.asarray([img.transpose(2, 0, 1)]))
probs = out['prob'][0]
# Class indices ordered from most to least probable.
sorted_indices = np.flipud(np.argsort(probs))
# Clamp so a net with fewer than TOP_K outputs does not index past the end.
for rank in range(min(TOP_K, len(sorted_indices))):
    idx = sorted_indices[rank]
    print("label is %s with %.4f prob" % (sin_labels[idx], probs[idx]))
print("predicted label: %d [ %s ]" % (sorted_indices[0], sin_labels[sorted_indices[0]]))
# NOTE(review): the commented-out LMDB evaluation that follows applies
# `image.transpose()` ((C, H, W) -> (W, H, C)) and then `transpose(2,0,1)`,
# which yields (C, W, H): height and width end up swapped. Fix that before
# re-enabling it.
# db_path = './sin_test_lmdb'
# lmdb_env = lmdb.open(db_path)
# lmdb_txn = lmdb_env.begin()
# lmdb_cursor = lmdb_txn.cursor()
# count = 0
# correct = 0
# for key, value in lmdb_cursor:
# # print "Count:"
# # print count
# count = count + 1
# datum = caffe.proto.caffe_pb2.Datum()
# datum.ParseFromString(value)
# label = int(datum.label)
# image = np.zeros((datum.channels, datum.height, datum.width))
# image = caffe.io.datum_to_array(datum)
# image = image.transpose()
# out = net.forward_all(data=np.asarray([image.transpose(2,0,1)]))
# predicted_label = out['prob'][0].argmax(axis=0)
# # print out['prob']
# # print predicted_label
# if label == predicted_label:
# correct = correct + 1
# print("Label is class " + str(label) + ", predicted class is " + str(predicted_label))
# print(str(correct) + " out of " + str(count) + " were classified correctly")