-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
98 lines (70 loc) · 2.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from datetime import datetime
import logging
import os
import sys
import torch
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
def interleave(x, bt):
    """Reorder the leading dimension of ``x`` by a reshape/transpose shuffle.

    The first dimension is viewed as a ``(-1, bt)`` grid, the two grid axes are
    swapped, and the result is flattened back. Trailing dimensions are carried
    along untouched. ``de_interleave(..., bt)`` undoes this.

    Args:
        x: Tensor whose first dimension is divisible by ``bt`` (``reshape``
            raises otherwise).
        bt: Size of the second grid axis.

    Returns:
        Tensor with the same shape as ``x`` and permuted leading entries.
    """
    tail = list(x.shape)[1:]
    grid = x.reshape([-1, bt] + tail)
    swapped = grid.transpose(0, 1)
    return swapped.reshape([-1] + tail)
def de_interleave(x, bt):
    """Undo ``interleave(x, bt)`` when called with the same ``bt``.

    The first dimension is viewed as a ``(bt, -1)`` grid, transposed, and
    flattened back. Trailing dimensions are preserved.

    Args:
        x: Tensor whose first dimension is divisible by ``bt``.
        bt: Size of the first grid axis (same value passed to ``interleave``).

    Returns:
        Tensor with the same shape as ``x`` in the restored order.
    """
    trailing = list(x.shape[1:])
    blocks = x.reshape([bt, -1] + trailing)
    return blocks.transpose(0, 1).reshape([-1] + trailing)
def setup_default_logging(args, default_level=logging.INFO,
                          format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"):
    """Configure file + console logging and create a TensorBoard writer.

    A log file named ``<time_str()>.log`` is created under
    ``<args.dataset>/x<args.n_labeled>/`` via ``logging.basicConfig`` (which
    configures the *root* logger). Records emitted on the ``'train'`` logger
    propagate to the root logger and thus reach that file. They are also echoed
    to stdout through a handler attached directly to ``'train'``.

    Args:
        args: Object exposing ``dataset`` and ``n_labeled`` attributes (the only
            fields read here).
        default_level: Level for the root logger and the console handler.
        format: ``logging`` format string shared by file and console output.

    Returns:
        ``(logger, writer)``: the ``'train'`` logger and a ``SummaryWriter``.
    """
    datefmt = "%m/%d/%Y %H:%M:%S"
    output_dir = os.path.join(args.dataset, f'x{args.n_labeled}')
    os.makedirs(output_dir, exist_ok=True)
    writer = SummaryWriter(comment=f'{args.dataset}_{args.n_labeled}')

    logger = logging.getLogger('train')
    # basicConfig() configures only the root logger and is a no-op if the root
    # logger already has handlers. 'train' reaches the file via propagation.
    # NOTE(review): time_str()'s default format contains ':', which is not a
    # valid filename character on Windows — consider passing an explicit fmt.
    logging.basicConfig(
        filename=os.path.join(output_dir, f'{time_str()}.log'),
        format=format,
        datefmt=datefmt,
        level=default_level)

    # getLogger('train') returns the same object on every call, so only attach
    # the stdout handler once; otherwise each call duplicates every console line.
    has_console = any(
        isinstance(h, logging.StreamHandler) and getattr(h, 'stream', None) is sys.stdout
        for h in logger.handlers
    )
    if not has_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(default_level)
        # Use the same datefmt as the file so timestamps match across outputs.
        console_handler.setFormatter(logging.Formatter(format, datefmt=datefmt))
        logger.addHandler(console_handler)
    return logger, writer
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor; classes lie along dim 1 (``topk`` is taken
            over that dimension), one row per sample.
        target: Tensor holding one ground-truth class index per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        List with one 0-dim float tensor per k: the percentage of samples whose
        target appears among their k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    # Indices of the maxk highest scores per row, sorted best-first.
    _, pred = output.topk(maxk, 1, largest=True, sorted=True)
    pred = pred.t()  # row i now holds every sample's i-th best guess
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so .view(-1) fails for k > 1; reshape() copies only when necessary.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0    # last value passed to update()
        self.avg = 0    # sum / count; stays 0 until some weight is recorded
        self.sum = 0    # running sum of val * n
        self.count = 0  # running sum of n

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``.

        ``val`` is weighted by ``n`` in the running average, i.e. it is treated
        as the mean over ``n`` items.

        Args:
            val: The new value.
            n: Weight of ``val``; ``0`` records ``val`` without changing ``avg``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 on a fresh meter would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count
def time_str(fmt=None):
    """Return the current local date/time rendered with ``fmt``.

    Args:
        fmt: ``strftime`` pattern; ``None`` selects ``'%Y-%m-%d_%H:%M:%S'``.

    Returns:
        The formatted timestamp string.
    """
    pattern = '%Y-%m-%d_%H:%M:%S' if fmt is None else fmt
    return datetime.now().strftime(pattern)
if __name__ == '__main__':
    # Demo: de_interleave(interleave(x, bt), bt) should reproduce x.
    original = torch.arange(30)
    mixed = interleave(original, 15)
    restored = de_interleave(mixed, 15)
    print(original, mixed, restored)