-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
51 lines (41 loc) · 1.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import numpy as np
import sys
import subprocess
import os
import yaml
import chainer
from chainer import cuda, optimizers, serializers
from chainer import training
from chainercv.links import PixelwiseSoftmaxClassifier
from erfnet.config_utils import *
# Set Chainer's CUDA max workspace size to 1024**3 (presumably bytes, i.e.
# 1 GiB -- TODO confirm the unit against the Chainer documentation).
chainer.cuda.set_max_workspace_size(1024 * 1024 * 1024)
# Set the CHAINER_TYPE_CHECK environment variable to "0" (presumably to turn
# off Chainer's type checking).
# NOTE(review): this assignment runs after `import chainer`; if Chainer only
# reads the variable at import time it has no effect here -- confirm, and
# move it above the import if so.
os.environ["CHAINER_TYPE_CHECK"] = "0"
from collections import OrderedDict
# Register a constructor for YAML's default mapping tag so that mappings are
# built as OrderedDict from the node's key/value pairs.
yaml.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
lambda loader, node: OrderedDict(loader.construct_pairs(node)))
from erfnet.models import erfnet_paper
def train_erfnet():
    """Run one ERFNet training session driven by the parsed configuration.

    Each stage (dataset, iterators, model, optimizer, updater, trainer
    extensions) is built from the matching section of the config returned by
    ``parse_args()``. After the trainer finishes, the predictor's parameters
    are written to ``model.npz`` inside the configured results directory.
    """
    # Chainer debug mode is switched on for the whole run.
    chainer.config.debug = True

    cfg = parse_args()

    # Data pipeline.
    train_set, test_set = load_dataset(cfg["dataset"])
    train_it, test_it = create_iterator(train_set, test_set, cfg["iterator"])

    # Network wrapped in a pixel-wise softmax classifier with class weights.
    net = get_model(cfg["model"])
    weights = get_class_weight(cfg)
    classifier = PixelwiseSoftmaxClassifier(net, class_weight=weights)

    # Optimisation machinery.
    opt = create_optimizer(cfg["optimizer"], classifier)
    device_spec = parse_devices(cfg["gpus"])
    updater = create_updater(train_it, opt, cfg["updater"], device_spec)

    # Trainer plus its extensions; evaluation uses the bare predictor.
    trainer = training.Trainer(updater, cfg["end_trigger"], out=cfg["results"])
    trainer = create_extension(
        trainer,
        test_it,
        classifier.predictor,
        cfg["extension"],
        devices=device_spec,
    )
    trainer.run()

    # Persist only the predictor, not the loss wrapper around it.
    snapshot_path = os.path.join(cfg["results"], "model.npz")
    chainer.serializers.save_npz(snapshot_path, classifier.predictor)
def main():
    """Script entry point: delegate to ``train_erfnet``."""
    train_erfnet()
# Only start training when this file is executed as a script.
if __name__ == "__main__":
    main()