-
Notifications
You must be signed in to change notification settings - Fork 187
/
config.py
59 lines (53 loc) · 1.82 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import json
# Default network options.
_NETWORK_DEFAULTS = {
    "arch": "resnet101",
    "activation": "relu",  # one of: relu, leaky_relu, elu, identity
    "activation_param": 0.01,  # leaky_relu slope / elu alpha
    "input_3x3": False,
    "bn_mode": "standard",  # one of: standard, inplace, sync
    "classes": 1000,
    "dilation": 1,
    "weight_gain_multiplier": 1,  # ignored when weight_init is kaiming_*
    # one of: xavier_[normal,uniform], kaiming_[normal,uniform], orthogonal
    "weight_init": "xavier_normal",
}

# Default learning-rate schedule options (nested under the optimizer section).
_SCHEDULE_DEFAULTS = {
    "type": "constant",  # one of: constant, step, multistep, exponential, linear
    "mode": "epoch",  # one of: epoch, step
    "epochs": 10,
    "params": {},
}

# Default optimizer options.
_OPTIMIZER_DEFAULTS = {
    "batch_size": 256,
    "type": "SGD",  # one of: SGD, Adam
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "clip": 1.0,
    "learning_rate": 0.1,
    "classifier_lr": -1.0,  # -1 means: same learning rate as the rest of the network
    "nesterov": False,
    "schedule": _SCHEDULE_DEFAULTS,
}

# Default input pre-processing options.
_INPUT_DEFAULTS = {
    "scale_train": -1,  # -1 means: do not scale
    "crop_train": 224,
    "color_jitter_train": False,
    "lighting_train": False,
    "scale_val": 256,  # -1 means: do not scale
    "crop_val": 224,
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Complete default configuration, grouped by section.
DEFAULTS = {
    "network": _NETWORK_DEFAULTS,
    "optimizer": _OPTIMIZER_DEFAULTS,
    "input": _INPUT_DEFAULTS,
}
def _merge(src, dst):
    """Recursively fill keys missing from ``dst`` with values from ``src``.

    Values already present in ``dst`` are never overwritten. When a key exists
    in both mappings and the ``src`` value is a dict, the merge descends into
    that pair so that nested options missing from ``dst`` are filled in too.

    Args:
        src: Mapping that provides the fallback values; it is only read.
        dst: Mapping that is updated in place.
    """
    for k, v in src.items():
        if k not in dst:
            # Insert a deep copy: storing ``v`` itself would alias nested
            # dicts/lists between ``dst`` and ``src``, so a later in-place
            # edit of the merged result would silently change ``src``.
            dst[k] = copy.deepcopy(v)
        elif isinstance(v, dict):
            # NOTE(review): assumes dst[k] is also a dict here - confirm
            # whether malformed configs need an explicit error.
            _merge(v, dst[k])
def load_config(config_file, defaults=DEFAULTS):
    """Load a JSON configuration file and fill in options it does not set.

    Args:
        config_file: Path of the JSON file to read.
        defaults: Mapping of fallback values; every key (at any nesting
            level) that the file omits is taken from here via ``_merge``.

    Returns:
        The parsed configuration, completed with the fallback values.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; decode it explicitly rather than
    # relying on the platform's locale encoding.
    with open(config_file, "r", encoding="utf-8") as fd:
        config = json.load(fd)
    _merge(defaults, config)
    return config