-
Notifications
You must be signed in to change notification settings - Fork 2
/
herbarium_phenology_dnn.py
93 lines (84 loc) · 2.97 KB
/
herbarium_phenology_dnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
from training import train_command
from prediction import predict_command
def build_parser(train_func, predict_func):
    """Build the command-line parser for the herbarium phenology script.

    The parser has global options (dataset location, task, subset, batch
    size, image preprocessing flags, number of data-loading workers)
    followed by a *required* action sub-command:

    * ``train``   -- positional ``experiment_output_path`` plus training
      hyper-parameters; ``args.func`` is set to ``train_func``.
    * ``predict`` -- positionals ``model_file`` and
      ``output_predictions_file``; ``args.func`` is set to ``predict_func``.

    The global options belong to the top-level parser, so on the command
    line they must appear *before* the action name, e.g.::

        --dataset_root DIR --task fertility --subset train --batch_size 32 \
            train OUT_DIR --num_epochs 10 --lr 0.01

    Args:
        train_func: callable stored as ``args.func`` when ``train`` is chosen.
        predict_func: callable stored as ``args.func`` when ``predict`` is
            chosen.

    Returns:
        The configured ``argparse.ArgumentParser``. Because an action is
        required, a successfully parsed namespace always has ``func``;
        omitting the action makes ``parse_args`` exit with a usage error.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Required options use default=argparse.SUPPRESS so that
    # ArgumentDefaultsHelpFormatter does not append "(default: None)" to them.
    parser.add_argument(
        '--dataset_root', type=str, required=True, default=argparse.SUPPRESS,
        help='path to datasets'
    )
    parser.add_argument(
        '--task', type=str, required=True, default=argparse.SUPPRESS,
        choices=['fertility', 'flower/fruit', 'phenophase'],
        help='which task to use for biggest dataset'
    )
    parser.add_argument(
        '--subset', type=str,
        choices=[
            'train', 'test', 'random_test', 'species_test', 'herbarium_test'
        ],
        required=True, default=argparse.SUPPRESS,
        help='which subset to use'
    )
    parser.add_argument(
        '--batch_size', type=int, required=True, default=argparse.SUPPRESS,
        help='training batch size'
    )
    parser.add_argument(
        '--keep_image_ratio', action='store_true',
        help='image preprocessing that preserves the image ratio'
    )
    parser.add_argument(
        '--downsample_image', action='store_true',
        help='image preprocessing that downsamples the image by a '
             'factor of 2'
    )
    parser.add_argument(
        '--num_workers', type=int, default=8,
        help='number of jobs for data loading'
    )

    # In Python 3, sub-commands are optional by default. Without an explicit
    # requirement, a bare invocation would parse "successfully" and then
    # crash on the missing ``args.func``. The metavar gives argparse a name
    # to print in the "required" error (some versions fail to format the
    # message when the sub-parser action has no dest/metavar).
    subparsers = parser.add_subparsers(
        help='action: train or predict',
        metavar='{train,predict}',
    )
    subparsers.required = True

    # Subparser for training
    parser_train = subparsers.add_parser(
        'train', help='perform training',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser_train.add_argument('experiment_output_path')
    parser_train.add_argument(
        '--model', type=str, default='resnet50',
        help='model to finetune'
    )
    parser_train.add_argument(
        '--num_epochs', type=int, required=True, default=argparse.SUPPRESS,
        help='max number of epochs for training'
    )
    parser_train.add_argument(
        '--lr', type=float, required=True, default=argparse.SUPPRESS,
        help='learning rate'
    )
    # NOTE(review): the milestones arrive here as a raw string and are
    # interpreted outside this file. If the handler eval()s it, prefer
    # ast.literal_eval -- TODO confirm in the training module.
    parser_train.add_argument(
        '--lr_decay', type=str, default=None,
        help='use multistep lr decay, pass a string containing the milestones,'
             ' e.g. "[1./3, 2./3]"'
    )
    parser_train.add_argument(
        '--data_augmentation', action='store_true',
        help='data augmentation to use during training'
    )
    parser_train.set_defaults(func=train_func)

    # Subparser for prediction
    parser_predict = subparsers.add_parser(
        'predict', help='predict on val/test',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser_predict.add_argument('model_file')
    parser_predict.add_argument('output_predictions_file')
    parser_predict.set_defaults(func=predict_func)

    return parser


if __name__ == '__main__':
    # Parse command line arguments, then delegate to the selected action
    # handler (train_command or predict_command), which receives the full
    # parsed namespace.
    args = build_parser(train_command, predict_command).parse_args()
    args.func(args)