-
Notifications
You must be signed in to change notification settings - Fork 2
/
tracing_utils.py
36 lines (32 loc) · 1.14 KB
/
tracing_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
import numpy as np
class LossHistory:
    """Accumulate per-update training metrics and print the latest snapshot.

    Each call to :meth:`update` appends one entry to every history list, so
    all lists always have the same length.

    Attributes:
        name: Label printed by :meth:`show_last`.
        X_test, y_test: Held-out data passed to
            ``model.expected_hamming_loss`` on every update.
        hamming_loss: Values of ``model.expected_hamming_loss(X_test, y_test)``.
        crm_loss: Values of ``model.crm_loss(crm_dataset)``.
        betas: Snapshots (copies) of ``model.beta``.
        n_samples: ``len(crm_dataset)`` at each update.
        n_actions: ``np.sum(crm_dataset.actions)`` at each update.
        rewards: ``np.sum(crm_dataset.rewards)`` at each update.
    """

    def __init__(self, name, X_test, y_test):
        self.name = name
        self.X_test = X_test
        self.y_test = y_test
        self.hamming_loss = []
        self.crm_loss = []
        self.betas = []
        self.n_samples = []
        self.n_actions = []
        self.rewards = []

    def update(self, model, crm_dataset):
        """Record one snapshot of ``model`` and ``crm_dataset`` statistics.

        Args:
            model: Object exposing ``beta`` (array-like),
                ``expected_hamming_loss(X, y)`` and ``crm_loss(dataset)``.
            crm_dataset: Object supporting ``len()`` and exposing ``actions``
                and ``rewards`` attributes that ``np.sum`` can reduce.
        """
        # Copy beta: if the model updates its parameters in place, storing
        # the reference would make every history entry alias the same array.
        # NOTE(review): np.copy converts array-likes to ndarray — assumes
        # model.beta is numpy-compatible (as np.sum/`**` use elsewhere implies).
        self.betas.append(np.copy(model.beta))
        self.hamming_loss.append(
            model.expected_hamming_loss(self.X_test, self.y_test)
        )
        self.crm_loss.append(model.crm_loss(crm_dataset))
        self.n_samples.append(len(crm_dataset))
        self.n_actions.append(np.sum(crm_dataset.actions))
        self.rewards.append(np.sum(crm_dataset.rewards))

    def show_last(self):
        """Print the most recent snapshot as a single line on stderr.

        If :meth:`update` has never been called, a placeholder line is
        printed instead of raising ``IndexError``.
        """
        if not self.betas:
            print('<', self.name, '| no updates recorded >', file=sys.stderr)
            return
        print(
            '<', self.name,
            '| Ham. loss: %.5f' % self.hamming_loss[-1],
            '| CRM loss: %.5f' % self.crm_loss[-1],
            # Euclidean (Frobenius for matrices) norm of the coefficients.
            '|beta|=%.2f' % np.sqrt((self.betas[-1]**2).sum()),
            'n=%d' % self.n_samples[-1],
            '|A|=%d' % self.n_actions[-1],
            '|R|=%d' % self.rewards[-1],
            '>',
            file=sys.stderr
        )