-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
40 lines (33 loc) · 1.12 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import make_data
from network import my_network
import os
from tensorflow.keras import losses
from settings_tf import config
# Restrict training to physical GPUs 2, 3 and 4. CUDA documents this variable
# as a comma-separated list of device indices, so the list is written without
# spaces rather than relying on the driver tolerating whitespace.
# NOTE(review): run() does not use a tf.distribute strategy, so Model.fit
# computes on a single device, while TensorFlow by default still reserves
# memory on every visible GPU. Confirm that exposing three GPUs is intended.
# NOTE(review): this is set after the imports above; that presumably works
# because TensorFlow initialises CUDA lazily - confirm that make_data/network
# do not touch the GPU at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4'
# Load the whole training set once, at import time. Presumably returns
# (images, labels) arrays accepted by Keras Model.fit - see make_data.make().
train_images, train_labels = make_data.make()
# Number of training epochs.
epochs = 1000
# Mini-batch size; run() also uses it as the static batch dimension in
# model.build().
batch_size = 16
def run():
    """Build, train and save the model, then dump its training history.

    Trains on the module-level ``train_images`` / ``train_labels`` using the
    module-level ``epochs`` and ``batch_size`` settings. The trained model is
    saved to ``./model/exp1``. The per-epoch history is then written to
    ``./result.txt`` as four lines, each the ``str()`` of a Python list, in
    this order: training accuracy, training loss, validation loss, validation
    accuracy (no trailing newline).
    """
    model = my_network.MyModel(output=config.LABEL_SIZE)
    # Builds with a static batch dimension; assumes config.IMAGE_SIZE is a
    # tuple of per-sample dimensions - TODO confirm.
    model.build(input_shape=(batch_size,) + config.IMAGE_SIZE)
    model.summary()

    loss_fn = losses.MeanSquaredLogarithmicError()
    # NOTE(review): 'accuracy' paired with a regression loss (MSLE) may not be
    # a meaningful metric for these labels - confirm.
    model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
    history = model.fit(
        train_images,
        train_labels,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=config.train_and_val,
    )

    # Persist the model before writing the report, so that a failure while
    # building or writing result.txt (e.g. no 'val_*' entries in the history
    # when no validation split is used) cannot discard a finished training run.
    model.save('./model/exp1')

    metrics = history.history
    report = [
        metrics['accuracy'],
        metrics['loss'],
        metrics['val_loss'],
        metrics['val_accuracy'],
    ]
    with open('./result.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(str(series) for series in report))
# Train only when executed as a script (e.g. `python train.py`), not when the
# module is imported.
if __name__ == '__main__':
    run()