-
Notifications
You must be signed in to change notification settings - Fork 0
/
backend_tf.py
69 lines (57 loc) · 2.04 KB
/
backend_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import onnx
import warnings
warnings.filterwarnings("ignore")
import os
import utils
# 0 = all messages are logged (default behavior)
# 1 = INFO messages are not printed
# 2 = INFO and WARNING messages are not printed
# 3 = INFO, WARNING, and ERROR messages are not printed
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from onnx_tf.backend import prepare
from onnx_tf.common import supports_device
import tensorflow as tf
import time
import backend
class BackendTensorflow(backend.Backend):
    """Benchmark backend that runs an ONNX model through onnx-tf on TensorFlow.

    The ONNX file is converted with ``onnx_tf.backend.prepare``; the converted
    graph is imported into its own ``tf.Graph`` and executed with the TF1-style
    ``tf.Session`` API. Only the first external input and the first external
    output of the converted graph are wired up.
    """

    def __init__(self):
        super(BackendTensorflow, self).__init__()
        self.session = None
        # Use the first GPU when onnx-tf reports CUDA support, else the CPU.
        self.device = "/device:GPU:0" if supports_device("CUDA") else "/cpu:0"
        utils.debug("running on {}".format(self.device))

    def name(self):
        """Return the backend's identifier string."""
        return "tensorflow"

    def version(self):
        """Return the installed TensorFlow version string."""
        return tf.__version__

    def load(self, model):
        """Load the ONNX model at ``model.path`` and build a session for it.

        Sets ``self.onnx_model``, ``self.model`` (the onnx-tf prepared rep),
        ``self.session``, and the ``self.inputs`` / ``self.outputs`` tensors
        used by ``forward_once``.

        Args:
            model: object with a ``path`` attribute pointing at an ONNX file
                (presumably the project's model descriptor -- confirm with caller).
        """
        utils.debug("loading onnx model {} from disk".format(model.path))
        self.onnx_model = onnx.load(model.path)
        utils.debug("loaded onnx model")
        with tf.device(self.device):
            self.model = prepare(self.onnx_model)
        utils.debug("prepared onnx model")

        predict_net = self.model.predict_net
        # Import into an explicit graph: tf.import_graph_def returns None when
        # no return_elements are requested, so passing its result as graph=
        # would silently fall back to the shared default graph, and repeated
        # load() calls would keep adding nodes to it.
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(predict_net.graph.as_graph_def(), name="")

        # Release the session from a previous load() before replacing it.
        if self.session is not None:
            self.session.close()
        self.session = tf.Session(graph=graph)

        # name="" above keeps the original tensor names, so the names recorded
        # in onnx-tf's tensor_dict resolve directly in the imported graph.
        self.inputs = graph.get_tensor_by_name(
            predict_net.tensor_dict[predict_net.external_input[0]].name
        )
        self.outputs = graph.get_tensor_by_name(
            predict_net.tensor_dict[predict_net.external_output[0]].name
        )
        utils.debug("loaded onnx model")

    def forward_once(self, img):
        """Run one inference on ``img`` and return the elapsed time in seconds.

        The network output is fetched but discarded; only the timing is
        returned. ``img`` is fed to the first graph input (presumably an array
        matching that input's shape/dtype -- confirm with caller).
        """
        start = time.perf_counter()
        self.session.run(self.outputs, {self.inputs: img})
        return time.perf_counter() - start

    def forward(self, img, warmup=True):
        """Time a single inference on ``img``.

        Args:
            img: value fed to the model's first input.
            warmup: if true, run one extra inference first and discard its
                timing (typically absorbs one-off setup costs).

        Returns:
            Elapsed seconds of the final, timed inference.
        """
        if warmup:
            self.forward_once(img)
        return self.forward_once(img)