keras_utils.py (forked from hse-aml/intro-to-dl)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import keras
import tqdm
from collections import defaultdict
import numpy as np


class TqdmProgressCallback(keras.callbacks.Callback):
    """Keras callback that reports training progress with a tqdm notebook bar."""

    def on_train_begin(self, logs=None):
        self.epochs = self.params['epochs']

    def on_epoch_begin(self, epoch, logs=None):
        print('Epoch %d/%d' % (epoch + 1, self.epochs))
        if "steps" in self.params:
            # Generator-style training: progress is measured in steps (batches).
            self.use_steps = True
            self.target = self.params['steps']
        else:
            # Array-style training: progress is measured in samples.
            self.use_steps = False
            self.target = self.params['samples']
        self.prog_bar = tqdm.tqdm_notebook(total=self.target)
        self.log_values_by_metric = defaultdict(list)

    def _set_prog_bar_desc(self, logs):
        # Accumulate each metric seen so far and show its running mean in the bar description.
        for k in self.params['metrics']:
            if k in logs:
                self.log_values_by_metric[k].append(logs[k])
        desc = "; ".join("{0}: {1:.3f}".format(k, np.mean(values))
                         for k, values in self.log_values_by_metric.items())
        self.prog_bar.set_description(desc)

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        if self.use_steps:
            self.prog_bar.update(1)
        else:
            batch_size = logs.get('size', 0)
            self.prog_bar.update(batch_size)
        self._set_prog_bar_desc(logs)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self._set_prog_bar_desc(logs)
        self.prog_bar.update(1)  # workaround to force the final description to render
        self.prog_bar.close()
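
A minimal usage sketch, not part of the original file: the model, data, and hyperparameters below are hypothetical placeholders, chosen only to show how the callback is passed to model.fit. It assumes standalone Keras 2.x (as used in the course notebooks), where self.params['samples'] and self.params['metrics'] are populated for this callback; verbose=0 disables Keras' built-in progress output so the tqdm bar is the only progress display.

import numpy as np
import keras
from keras_utils import TqdmProgressCallback

# A tiny hypothetical regression model, just enough to drive the callback.
model = keras.models.Sequential([
    keras.layers.Dense(32, activation='relu', input_shape=(10,)),
    keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# Random placeholder data.
X = np.random.rand(256, 10)
y = np.random.rand(256, 1)

# verbose=0 silences the default Keras progress bar; TqdmProgressCallback
# prints the epoch header and draws the tqdm notebook bar instead.
model.fit(X, y,
          epochs=3,
          batch_size=32,
          callbacks=[TqdmProgressCallback()],
          verbose=0)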