-
Notifications
You must be signed in to change notification settings - Fork 306
/
create_pascal_tf_record.py
144 lines (110 loc) · 5.13 KB
/
create_pascal_tf_record.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Converts PASCAL dataset to TFRecords file format."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import io
import os
import sys
import PIL.Image
import tensorflow as tf
from utils import dataset_util
def _build_parser():
    """Create the command-line parser for the PASCAL-to-TFRecord converter.

    Every option is a string-valued flag with a default value, registered in
    the order listed below.

    Returns:
      The configured argparse.ArgumentParser.
    """
    # (flag, default, help text) for each supported option.
    specs = (
        ('--data_dir', './dataset/VOCdevkit/VOC2012',
         'Path to the directory containing the PASCAL VOC data.'),
        ('--output_path', './dataset',
         'Path to the directory to create TFRecords outputs.'),
        ('--train_data_list', './dataset/train.txt',
         'Path to the file listing the training data.'),
        ('--valid_data_list', './dataset/val.txt',
         'Path to the file listing the validation data.'),
        ('--image_data_dir', 'JPEGImages',
         'The directory containing the image data.'),
        ('--label_data_dir', 'SegmentationClassAug',
         'The directory containing the augmented label data.'),
    )
    cli = argparse.ArgumentParser()
    for flag, default, description in specs:
        cli.add_argument(flag, type=str, default=default, help=description)
    return cli


parser = _build_parser()
def _read_encoded_image(path, expected_format, format_error):
    """Read an image file and check that PIL identifies it as expected_format.

    Args:
      path: Path to the image file.
      expected_format: The PIL format name the file must have (e.g. 'JPEG').
      format_error: Message for the ValueError raised when the check fails.
    Returns:
      A tuple (encoded, size): the raw bytes of the file and the
      (width, height) pair reported by PIL.
    Raises:
      ValueError: if PIL cannot identify the file as an image or identifies it
        as a format other than expected_format.
    """
    with tf.gfile.GFile(path, 'rb') as fid:
        encoded = fid.read()
    try:
        decoded = PIL.Image.open(io.BytesIO(encoded))
    except IOError:
        # PIL signals an unidentifiable/corrupt file with IOError (OSError on
        # Python 3). Re-raise as ValueError so the documented contract holds
        # and callers that skip invalid examples also skip this one.
        raise ValueError(format_error)
    if decoded.format != expected_format:
        raise ValueError(format_error)
    return encoded, decoded.size


def dict_to_tf_example(image_path,
                       label_path):
    """Convert image and label to tf.Example proto.

    The image and label are stored in their original encoded form (the bytes
    read from disk); PIL is only used to validate format and size.

    Args:
      image_path: Path to a single PASCAL image.
      label_path: Path to its corresponding label.
    Returns:
      example: The converted tf.Example.
    Raises:
      ValueError: if the image pointed to by image_path is not a valid JPEG or
        if the label pointed to by label_path is not a valid PNG or
        if the size of image does not match with that of label.
    """
    encoded_jpg, image_size = _read_encoded_image(
        image_path, 'JPEG', 'Image format not JPEG')
    encoded_label, label_size = _read_encoded_image(
        label_path, 'PNG', 'Label format not PNG')

    if image_size != label_size:
        raise ValueError('The size of image does not match with that of label.')

    # PIL reports size as (width, height).
    width, height = image_size

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'label/encoded': dataset_util.bytes_feature(encoded_label),
        'label/format': dataset_util.bytes_feature('png'.encode('utf8')),
    }))
    return example
def create_tf_record(output_filename,
                     image_dir,
                     label_dir,
                     examples):
    """Creates a TFRecord file from examples.

    Examples whose image or label file is missing, or whose conversion raises
    ValueError, are logged with a warning and skipped.

    Args:
      output_filename: Path to where output file is saved.
      image_dir: Directory where image files are stored.
      label_dir: Directory where label files are stored.
      examples: Sequence of example names; '.jpg' and '.png' are appended to
        each name to locate its image and label file respectively.
    """
    writer = tf.python_io.TFRecordWriter(output_filename)
    try:
        for idx, example in enumerate(examples):
            if idx % 500 == 0:
                tf.logging.info('On image %d of %d', idx, len(examples))
            image_path = os.path.join(image_dir, example + '.jpg')
            label_path = os.path.join(label_dir, example + '.png')

            if not os.path.exists(image_path):
                tf.logging.warning('Could not find %s, ignoring example.', image_path)
                continue
            if not os.path.exists(label_path):
                tf.logging.warning('Could not find %s, ignoring example.', label_path)
                continue

            try:
                tf_example = dict_to_tf_example(image_path, label_path)
                writer.write(tf_example.SerializeToString())
            except ValueError:
                tf.logging.warning('Invalid example: %s, ignoring.', example)
    finally:
        # Always release the writer, even when an unexpected error aborts the
        # loop, so the output file handle is not leaked.
        writer.close()
def main(unused_argv):
    """Write train.record and val.record for the configured PASCAL VOC data.

    Reads its settings from the module-level FLAGS namespace. The output
    directory is created when it does not exist, both example lists are read,
    and then one TFRecord file per list is written into the output directory.

    Args:
      unused_argv: Not used.
    Raises:
      ValueError: if the label directory does not exist under the data
        directory.
    """
    output_dir = FLAGS.output_path
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    tf.logging.info("Reading from VOC dataset")
    data_root = FLAGS.data_dir
    image_dir = os.path.join(data_root, FLAGS.image_data_dir)
    label_dir = os.path.join(data_root, FLAGS.label_data_dir)

    if not os.path.isdir(label_dir):
        raise ValueError("Missing Augmentation label directory. "
                         "You may download the augmented labels from the link (Thanks to DrSleep): "
                         "https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip")

    # Read both lists up front so a missing list file fails before any record
    # file is written.
    splits = (
        ('train.record', dataset_util.read_examples_list(FLAGS.train_data_list)),
        ('val.record', dataset_util.read_examples_list(FLAGS.valid_data_list)),
    )
    for record_name, split_examples in splits:
        record_path = os.path.join(output_dir, record_name)
        create_tf_record(record_path, image_dir, label_dir, split_examples)
if __name__ == '__main__':
    # Show INFO-level log messages emitted during the conversion.
    tf.logging.set_verbosity(tf.logging.INFO)
    # FLAGS is bound at module level so that main() can read it. Arguments the
    # parser does not recognise are forwarded to tf.app.run after the program
    # name.
    FLAGS, unparsed = parser.parse_known_args()
    forwarded_argv = [sys.argv[0]]
    forwarded_argv.extend(unparsed)
    tf.app.run(main=main, argv=forwarded_argv)