-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf2_validateModels.py
52 lines (42 loc) · 1.91 KB
/
tf2_validateModels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from tabnanny import verbose
import numpy as np
import cv2
import os
import argparse
import time
import preprocess
import sys
import random
from myUtils import DatasetGenerator
import tensorflow as tf
from tensorflow_model_optimization.quantization.keras import vitis_quantize
# def load_and_preprocess(imageSize): # Creare il dataset con anche le label, usa la classe Dataset già per leggere i file
# images = os.listdir(os.path.join("calibration_images","ILSVRC2012_img_val"))
# random.shuffle(images)
# images = images[:1000]
# preprocessed_images = []
# print("\nStart preprocessing the calibration images\n")
# for i in images:
# img = cv2.imread(os.path.join("calibration_images","ILSVRC2012_img_val",i))
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = preprocess.central_crop(img, imageSize, imageSize)
# img = preprocess.tf_imagenet_preprocess(img)
# # img = np.expand_dims(img, 0)
# if img.shape == (224, 224, 3):
# preprocessed_images.append(img)
# print("\nEnd preprocessing the calibration images\n")
# # return np.array(preprocessed_images)
# return tf.data.Dataset.from_tensor_slices(preprocessed_images)
# random.seed(42)
# calibrationDataset = load_and_preprocess(224)
# batched_dataset = calibrationDataset.batch(32, drop_remainder=True)
def main(argv=None):
    """Evaluate a Vitis-AI quantized Keras model on a range of dataset images.

    Builds the evaluation dataset with ``DatasetGenerator.make_dataset()``,
    loads the quantized ``.h5`` model inside ``vitis_quantize.quantize_scope()``
    (presumably required so Keras can resolve the custom quantization layers
    when deserializing — confirm against the Vitis AI docs), compiles it with a
    sparse categorical cross-entropy loss and a sparse top-k accuracy metric,
    and runs ``evaluate``.

    Every setting is a command-line option whose default is the value this
    script previously hard-coded, so running it with no arguments performs the
    same evaluation as before.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The value returned by ``Model.evaluate`` (the loss followed by the
        top-k accuracy).
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a Vitis-AI quantized Keras model.")
    parser.add_argument(
        "--model",
        default=os.path.join("tf2_vai_quant_models", "my_mobilenet",
                             "quantized_my_mobilenet.h5"),
        help="path to the quantized .h5 model")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="batch size passed to DatasetGenerator")
    parser.add_argument("--start", type=int, default=1024,
                        help="startImageNumber passed to DatasetGenerator")
    parser.add_argument("--stop", type=int, default=2048,
                        help="stopImageNumber passed to DatasetGenerator")
    parser.add_argument("--width", type=int, default=224,
                        help="image width passed to DatasetGenerator")
    parser.add_argument("--height", type=int, default=224,
                        help="image height passed to DatasetGenerator")
    parser.add_argument("--top-k", type=int, default=5,
                        help="k for SparseTopKCategoricalAccuracy "
                             "(5 is the Keras default)")
    args = parser.parse_args(argv)

    print("\nStart make dataset")
    dataset_generator = DatasetGenerator(
        batch_size=args.batch_size,
        startImageNumber=args.start,
        stopImageNumber=args.stop,
        width=args.width,
        height=args.height,
    )
    batched_dataset = dataset_generator.make_dataset()
    print("\nStop make dataset")

    # Loading (and compiling/evaluating) must happen inside the quantize scope;
    # the body of this `with` was previously unindented, which is a syntax error.
    with vitis_quantize.quantize_scope():
        quantized_model = tf.keras.models.load_model(args.model)
        # NOTE(review): SparseCategoricalCrossentropy() uses from_logits=False,
        # which assumes the model ends in a softmax — confirm for this model.
        quantized_model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            # Keras documents `metrics` as a list of metrics.
            metrics=[
                tf.keras.metrics.SparseTopKCategoricalAccuracy(k=args.top_k)
            ],
        )
        return quantized_model.evaluate(batched_dataset, verbose=2)


if __name__ == "__main__":
    main()