-
Notifications
You must be signed in to change notification settings - Fork 1
/
argparse_example.py
140 lines (123 loc) · 4.34 KB
/
argparse_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import pathlib
import sys
import pandas as pd
import SimpleITK as sitk
# Custom argparse "type" functions: each converts and validates one command-line
# string, raising argparse.ArgumentTypeError so argparse can report a clean usage error.
def file_path(path):
    """Argparse type: accept ``path`` only if it names an existing regular file.

    Returns the path as a pathlib.Path; otherwise raises
    argparse.ArgumentTypeError so argparse prints a usage error.
    """
    candidate = pathlib.Path(path)
    if not candidate.is_file():
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({path}), not a file path or file does not exist."
        )
    return candidate
def dir_path(path):
    """Argparse type: accept ``path`` only if it names an existing directory.

    Returns the path as a pathlib.Path; otherwise raises
    argparse.ArgumentTypeError so argparse prints a usage error.
    """
    directory = pathlib.Path(path)
    if directory.is_dir():
        return directory
    message = (
        f"Invalid argument ({path}), not a directory path or directory does not exist."
    )
    raise argparse.ArgumentTypeError(message)
def nonnegative_int(i):
    """Argparse type: parse ``i`` with int() and require the result to be >= 0.

    A non-integer string makes int() raise ValueError; a negative value
    raises argparse.ArgumentTypeError. Returns the parsed int.
    """
    value = int(i)
    if value < 0:
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({i}), expected value >= 0 ."
        )
    return value
def positive_int(i):
    """Argparse type: parse ``i`` with int() and require the result to be > 0.

    A non-integer string makes int() raise ValueError; zero or a negative
    value raises argparse.ArgumentTypeError. Returns the parsed int.
    """
    parsed = int(i)
    if parsed <= 0:
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({i}), expected value > 0 ."
        )
    return parsed
def modality_dir_path(path, modality):
    """Argparse type: a directory whose first DICOM series has the given modality.

    The file names of the first DICOM series found in ``path`` are obtained
    via SimpleITK/GDCM. The header of the first file is read, and its
    Modality tag (0008|0060), stripped of surrounding whitespace, is compared
    with ``modality``.

    Args:
        path: directory path string from the command line.
        modality: expected DICOM modality code (e.g. 'CT', 'CR').

    Returns:
        pathlib.Path for ``path``.

    Raises:
        argparse.ArgumentTypeError: if ``path`` is not a directory, contains
            no DICOM series, its first file cannot be read, the file has no
            Modality tag, or the modality does not match.
    """
    p = pathlib.Path(path)
    if not p.is_dir():
        raise argparse.ArgumentTypeError(f"Invalid argument ({path}), not a directory.")

    # An empty result means no DICOM series was found. Indexing it unchecked
    # raised IndexError, which argparse does not translate into a usage error.
    series_files = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(str(p))
    if not series_files:
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({path}), directory does not contain a DICOM series."
        )

    reader = sitk.ImageFileReader()
    reader.SetFileName(series_files[0])
    try:
        # Reads only the image header/metadata, not the pixel data.
        reader.ReadImageInformation()
    except RuntimeError as err:
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({path}), could not read DICOM file {series_files[0]}."
        ) from err

    modality_tag = "0008|0060"
    # Guard with HasMetaDataKey: GetMetaData raises when the tag is absent.
    if (
        not reader.HasMetaDataKey(modality_tag)
        or reader.GetMetaData(modality_tag).strip() != modality
    ):
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({path}), first series in directory is not {modality} modality."
        )
    return p
def csv_path(path, required_columns=()):
    """Argparse type: a path to a CSV file whose header has the required columns.

    Only the header row is read (``nrows=0``), so large files are cheap to
    validate. Use with argparse via a lambda, e.g.
    ``type=lambda x: csv_path(x, ["id", "label"])``.

    Args:
        path: file path string from the command line.
        required_columns: iterable of column names that must all appear in
            the header. The default is an empty tuple, so no columns are
            required. It replaces the former mutable ``{}`` default.

    Returns:
        pathlib.Path for ``path``.

    Raises:
        argparse.ArgumentTypeError: if ``path`` is not a file, cannot be
            decoded or parsed as CSV (including an empty file), or lacks
            any required column.
    """
    p = pathlib.Path(path)
    if not p.is_file():
        raise argparse.ArgumentTypeError(f"Invalid argument ({path}), not a file.")

    required = set(required_columns)
    try:
        header = pd.read_csv(path, nrows=0).columns
    except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        # EmptyDataError/ParserError previously escaped this function instead of
        # producing its own "not a csv file" message.
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({path}), not a csv file."
        ) from err

    missing = required.difference(header)
    if missing:
        raise argparse.ArgumentTypeError(
            f"Invalid argument ({path}), does not contain all expected columns "
            f"(missing: {sorted(missing, key=str)})."
        )
    return p
def main(argv=None):
    """Build the example argument parser, parse ``argv`` and print the result.

    Args:
        argv: list of argument strings; None lets argparse read sys.argv[1:].
    """
    parser = argparse.ArgumentParser(description="Argparse usage example")

    # Positional (required) arguments, validated by the custom type functions.
    parser.add_argument("input_data_file", type=file_path)
    parser.add_argument(
        "cxr_dir",
        type=lambda x: modality_dir_path(x, "CR"),
        help="path to Computed Radiography, chest x-ray, directory",
    )

    # Optional arguments (prefixed with --): flag, validating type, default.
    optional_specs = (
        ("--gpu_id", nonnegative_int, 0),
        ("--batch_size", positive_int, 8),
        ("--epochs", positive_int, 100),
        ("--lr", float, 0.0001),
    )
    for flag, validator, default_value in optional_specs:
        parser.add_argument(flag, type=validator, default=default_value)

    # nargs="*" collects zero or more values following the flag into a list
    # (use an integer N instead of "*" to require exactly N values).
    parser.add_argument("--exclude_label", type=str, nargs="*")

    print(parser.parse_args(argv))
if __name__ == "__main__":
    # Fallback arguments for debugging, e.g. when launched from an IDE with no
    # command line. Real command-line arguments are no longer silently ignored;
    # they take precedence when given.
    _DEBUG_ARGV = [
        "data/my_data.csv",
        "data",
        "--gpu_id",
        "1",
        "--batch_size",
        "8",
        "--epochs",
        "100",
        "--lr",
        "0.0001",
        "--exclude_label",
        "Cardiomegaly",
        "Pneumonia",
    ]
    sys.exit(main(sys.argv[1:] or _DEBUG_ARGV))