-
Notifications
You must be signed in to change notification settings - Fork 2
/
write_totfrec_val.py
67 lines (53 loc) · 2.53 KB
/
write_totfrec_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import six
import os
import pickle
import numpy as np
import tensorflow as tf
# Directory that each image filename from the labels pickle is joined onto
# before the image is read.
DATA_DIR = './images_all/images/'
# Directory the TFRecord shard files are written into.
OUTPUT_DIRECTORY = './tf_records/'
# Number of examples written into each shard file.
IMAGES_PER_RECORD_SHARD = 256
# Prefix of every shard file name ('<NAME>-<shard>-of-<total>').
NAME = 'val'
def _int64_feature(value):
    """Wrap an integer, or a list of integers, in an int64 tf.train.Feature.

    Args:
        value: a single value or a list of values. Anything that is not a
            list is treated as one value and wrapped in a one-element list.

    Returns:
        A tf.train.Feature whose int64_list carries the value(s).
    """
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto.

    Args:
        value: a bytes object, or a text string, which is UTF-8 encoded
            before being stored.

    Returns:
        A tf.train.Feature whose bytes_list holds `value` as its only element.
    """
    # Encode text only. The earlier form tested six.string_types and called
    # six.binary_type(value, 'utf-8'); on Python 2 that test also matches
    # byte strings and the call becomes str(value, 'utf-8'), which raises
    # TypeError. On Python 3 this branch behaves exactly as before.
    if isinstance(value, six.text_type):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _convert_to_example(image, label, filename):
    """Build a tf.train.Example for one image.

    Args:
        image: image content, stored under 'image/encoded' as a bytes feature.
        label: label value(s), stored under 'image/label' as an int64 feature.
        filename: file name, stored under 'image/filename' as a bytes feature.

    Returns:
        The assembled tf.train.Example.
    """
    feature_map = {
        'image/encoded': _bytes_feature(image),
        'image/label': _int64_feature(label),
        'image/filename': _bytes_feature(filename),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
def main():
    """Write the labelled images into sharded TFRecord files.

    Loads a {filename: label} mapping from './validation_labels_new.pkl',
    reads each file from DATA_DIR and writes one serialized Example per image
    into files named '<NAME>-<shard>-of-<total>' under OUTPUT_DIRECTORY.
    Each shard holds IMAGES_PER_RECORD_SHARD examples; the last shard holds
    whatever remains, so every listed file is written.
    """
    # NOTE: pickle.load can execute arbitrary code while unpickling; only use
    # a labels file that comes from a trusted source.
    with open('./validation_labels_new.pkl', 'rb') as f:
        labels = pickle.load(f)
    filenames = list(labels.keys())
    num_files = len(filenames)
    # Ceiling division: floor division dropped up to
    # IMAGES_PER_RECORD_SHARD - 1 trailing images and produced no output at
    # all when there were fewer files than one full shard.
    num_shards = (
        (num_files + IMAGES_PER_RECORD_SHARD - 1) // IMAGES_PER_RECORD_SHARD)
    # The record writer does not create missing directories.
    if not os.path.isdir(OUTPUT_DIRECTORY):
        os.makedirs(OUTPUT_DIRECTORY)
    current_file = 0
    for shard_num in range(num_shards):
        output_file = '%s-%.5d-of-%.5d' % (NAME, shard_num, num_shards)
        output_file = os.path.join(OUTPUT_DIRECTORY, output_file)
        # Clamp so the final shard stops at the end of the file list.
        shard_end = min(current_file + IMAGES_PER_RECORD_SHARD, num_files)
        with tf.python_io.TFRecordWriter(output_file) as writer:
            for filename in filenames[current_file:shard_end]:
                image_path = os.path.join(DATA_DIR, filename)
                with tf.gfile.FastGFile(image_path, 'rb') as image_file:
                    image_data = image_file.read()
                label = labels[filename]
                example = _convert_to_example(image_data, label, filename)
                print("===========================================================")
                print("Filename = %s" % (filename,))
                # One-element tuples keep '%' from unpacking a tuple label.
                print("Label Written = %s" % (label,))
                print("Pickle label = %s" % (labels[filename],))
                print("===========================================================")
                writer.write(example.SerializeToString())
        current_file = shard_end
        print('Finished writing shard %d of %d' % (shard_num, num_shards))
    print("Used %s files of %s" % (current_file, num_files))
# Run the conversion only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()