-
Notifications
You must be signed in to change notification settings - Fork 0
/
step3_generateTFRecords.py
104 lines (85 loc) · 3.1 KB
/
step3_generateTFRecords.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import glob
import uuid
import shutil
import csv
from libs import tfrg
def loadCSV(filename):
    """Load a CSV file into a column-oriented dictionary.

    The first row of the file is taken as the header row. Every following
    row contributes one value to the list stored under each header.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A dict mapping each header to the list of that column's values, in
        row order, as produced by csv.DictReader. A file without any data
        rows yields an empty dict.
    """
    # The 'U' flag of the former 'rU' mode was removed in Python 3.11, where
    # it makes open() raise ValueError. newline='' is the mode the csv module
    # documents for files handed to a reader.
    with open(filename, newline='') as infile:
        # read the file as a dictionary for each row ({header : value})
        reader = csv.DictReader(infile)
        data = {}
        for row in reader:
            for header, value in row.items():
                data.setdefault(header, []).append(value)
    return data
def getUniqueFilelist(filename):
    """Return the distinct entries of the 'filename' column of a CSV file.

    Args:
        filename: Path of the CSV label file; it is read through loadCSV.

    Returns:
        A list holding each distinct 'filename' value once, in order of
        first appearance (dict ordering, guaranteed on Python 3.7+). An
        empty list is returned when the file has no data rows.

    Raises:
        KeyError: If the file has data rows but no 'filename' column.
    """
    data = loadCSV(filename)
    if not data:
        # loadCSV returns {} when the file holds no data rows (empty or
        # header-only file); there is simply nothing to list in that case.
        return []
    # dict.fromkeys de-duplicates while keeping a stable order, unlike set().
    return list(dict.fromkeys(data['filename']))
def concatCSV(filelist, output):
    """Concatenate CSV files that share a header into a single file.

    The first line of the first non-empty input is written once as the
    header. The first line of every other input is dropped and its
    remaining lines are appended.

    Args:
        filelist: Iterable of paths of the CSV files to merge, in order.
        output: Path of the file to create (or overwrite). It is created
            empty when filelist is empty or contains only empty files.
    """
    header_saved = False
    with open(output, 'w') as fout:
        for filename in filelist:
            with open(filename) as fin:
                header = next(fin, None)
                if header is None:
                    # Empty input file: no header and no rows to copy.
                    continue
                if not header_saved:
                    fout.write(header if header.endswith('\n') else header + '\n')
                    header_saved = True
                for line in fin:
                    # Only a file's last line can lack its newline; add it
                    # so the next file's first row is not glued onto it.
                    fout.write(line if line.endswith('\n') else line + '\n')
def cleanup(session):
    """Delete the temporary working tree of a session.

    Args:
        session: Session identifier; the directory tmp/<session>, resolved
            against the current working directory, is removed recursively.
    """
    session_dir = "tmp/{}".format(session)
    shutil.rmtree(session_dir)
def main():
    """Build the train and test TFRecord files for one dataset.

    The dataset name is read from the first command-line argument. The
    label files datasets/<dataset>/train_labels.csv and test_labels.csv are
    gathered into tmp/<session>/, every image they reference is copied from
    datasets/<dataset>/images/ into tmp/<session>/images/, and tfrg.main is
    then called from inside the session directory once for "train" and once
    for "test", targeting datasets/<dataset>/TFRecords/<split>.record.

    Exits with a usage message when no dataset argument is given.
    """
    print("----------------------------")
    print("  TFRecord generator v0.1   ")
    print("By Giovanni Cimolin da Silva")
    print("----------------------------")

    if len(sys.argv) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("Usage: step3_generateTFRecords.py <dataset>")
    dataset = sys.argv[1]
    session = str(uuid.uuid4())
    print("Session UID: {}".format(session))

    # Set up working directories and paths
    owd = os.getcwd()  # Save directory for return point
    print("Creating directories...")
    # exist_ok tolerates directories left by an earlier run, while real
    # failures (e.g. permissions) surface here instead of being swallowed.
    os.makedirs("tmp/{}/images".format(session), exist_ok=True)
    os.makedirs("datasets/{}/TFRecords".format(dataset), exist_ok=True)

    print("Locating csv label files...")
    # Find all csv generated from the step before
    train_labels = glob.glob("datasets/{}/train_labels.csv".format(dataset), recursive=True)
    test_labels = glob.glob("datasets/{}/test_labels.csv".format(dataset), recursive=True)

    # Concatenate all train and test labels
    concatCSV(train_labels, "tmp/{}/train_labels.csv".format(session))
    concatCSV(test_labels, "tmp/{}/test_labels.csv".format(session))

    # Find and copy all the necessary files to the images folder
    print("Copying necessary files...")
    images = getUniqueFilelist("tmp/{}/train_labels.csv".format(session)) +\
        getUniqueFilelist("tmp/{}/test_labels.csv".format(session))
    for image in images:
        for filename in glob.iglob("datasets/{}/images/{}".format(dataset, image), recursive=True):
            shutil.copy2(filename, "tmp/{}/images/".format(session))

    # Open working directory and run TFRecord Generator
    print("Creating TFRecord files...")
    input("Press Enter to continue...")
    os.chdir("tmp/{}/".format(session))
    try:
        for item in ["train", "test"]:
            tfrg.main(
                "{}_labels.csv".format(item),
                "../../datasets/{}/TFRecords/{}.record".format(dataset, item)
            )
    finally:
        # Always return to the starting directory, even when the generator
        # raises, so the relative paths used elsewhere keep their meaning.
        os.chdir(owd)

    # Finishing
    print("Successfully created TFRecord files.")
    print("Cleaning up...")
    # NOTE(review): the cleanup(session) call is disabled, so tmp/<session>
    # stays on disk after the run - confirm whether that is intentional.
    #cleanup(session)
if __name__ == "__main__":
    # Run the generator only when the file is executed as a script, so that
    # importing the module does not start the pipeline (which blocks on an
    # interactive input() prompt).
    main()