-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
67 lines (57 loc) · 1.78 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import random
import numpy as np
import torch
def all_seed(seed):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def boolean_string(s):
if s not in {'False', 'True'}:
raise ValueError('Not a valid boolean string')
return s == 'True'
def cosine_similarity(x1, x2=None, eps=1e-8):
x2 = x1 if x2 is None else x2
w1 = x1.norm(p=2, dim=1, keepdim=True)
w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
sim = torch.mm(x1, x2.t())/(w1 * w2.t()).clamp(min=eps)
return sim
gpu = {-1:'cpu', 0:'cuda:0', 1:'cuda:1'}
class AverageMeter(object):
"""compute and store the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.sum = 0
self.count = 0
def update(self, val, n):
self.sum += val*n
self.count += n
def avg(self):
if self.count == 0:
return 0
return float(self.sum) / self.count
class EarlyStopping():
def __init__(self, min_delta, patience, cumulative_delta):
self.min_delta = min_delta
self.patience = patience
self.cumulative_delta = cumulative_delta
self.counter = 0
self.best_score = None
def step(self, score):
if self.best_score is None:
self.best_score = score
elif score <= self.best_score + self.min_delta:
self.counter += 1
if self.counter >= self.patience:
return True
else:
self.best_score = score
self.counter = 0
return False
def reset(self):
self.counter = 0
self.best_score = None