utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle as pkl


def pad_(X, max_atoms, dim=1):
    """Zero-pad the last `dim` dimensions of X on the right up to `max_atoms`.

    Assumes every padded dimension currently has size X.shape[0].
    """
    extra = max_atoms - X.shape[0]
    return F.pad(X, (0, extra) * dim).clone()


def pad_Dhat(Dh, max_atoms):
    """Zero-pad the first two dimensions of an (n, n, k) tensor to (max_atoms, max_atoms, k)."""
    extra = max_atoms - Dh.shape[0]
    return F.pad(Dh, (0, 0, 0, extra, 0, extra)).clone()


def create_dummy(num_atoms, total_atoms):
    """Generate a random molecule: atom types in [1, total_atoms] (0 is reserved for
    padding) and a symmetric distance matrix with a zero diagonal."""
    Z = torch.LongTensor(num_atoms).random_(total_atoms)
    D = torch.rand((num_atoms, num_atoms))
    D.masked_fill_(torch.eye(num_atoms).bool(), 0)  # zero out the diagonal
    return Z + 1, (D + D.T) / 2  # shift types past the padding index; symmetrize D


def create_dummy_batch(min_atoms, max_atoms, total_atoms, bs):
    """Generate a batch of `bs` random molecules, each zero-padded to `max_atoms`."""
    Zs, Ds, sizes = [], [], []
    for num_atoms in torch.randint(min_atoms, max_atoms, (bs,)):
        Z, D = create_dummy(num_atoms.item(), total_atoms)
        Zs.append(pad_(Z, max_atoms))
        Ds.append(pad_(D, max_atoms, dim=2))
        sizes.append(num_atoms.item())
    Zs = torch.stack(Zs)
    Ds = torch.stack(Ds)
    return Zs, Ds, torch.LongTensor(sizes)
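
# Example usage (sketch; the constants are illustrative, not taken from the project):
# create_dummy_batch(min_atoms=4, max_atoms=9, total_atoms=5, bs=2) returns Zs of
# shape (2, 9), Ds of shape (2, 9, 9) and sizes of shape (2,), where the actual
# molecule sizes are drawn from [4, 9).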

def transform_D(D, sz):
    """Broadcast D along a new trailing dimension of size `sz` (a view, no copy)."""
    shape = list(D.shape) + [sz]
    return D.unsqueeze(-1).expand(shape)


def create_mask(method):
    """Decorator: turn a per-molecule mask builder into a batched one that takes a
    tensor of molecule sizes and stacks the individual masks."""
    def fn(sizes, full_size):
        masks = [method(size.item(), full_size) for size in sizes]
        return torch.stack(masks)
    return fn


@create_mask
def mask_2d(size, full_size):
    """1 on the off-diagonal entries of the top-left `size` x `size` block, 0 elsewhere."""
    mask = torch.zeros((full_size, full_size))
    mask[np.diag_indices(size)] = 1
    mask[:size, :size] -= 1
    return mask.abs()  # abs() is out-of-place, so its result has to be returned


@create_mask
def mask_1d(size, full_size):
    """1 for the first `size` positions, 0 for the padded positions."""
    mask = torch.zeros((full_size,))
    mask[:size] = 1
    return mask
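
# Example (sketch): because of the create_mask decorator, both functions take a batch
# of sizes plus the padded length. With sizes = torch.tensor([2, 3]) and full_size = 4,
# mask_1d(sizes, 4) is a (2, 4) tensor whose first row is [1, 1, 0, 0], and
# mask_2d(sizes, 4) is a (2, 4, 4) tensor that is 1 only on the off-diagonal entries
# of the top-left 2x2 (resp. 3x3) block.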

def read_raw_data():
    """Load the preprocessed dataset from disk."""
    with open('data/preprocessed.pkl', 'rb') as f:
        data = pkl.load(f)
    return data


def process_data(data, max_atoms):
    """Turn each molecule's (Z, D) pair into padded tensors and keep its true size."""
    res = {}
    for smile, (Z, D) in data.items():
        res[smile] = (pad_(torch.LongTensor(Z), max_atoms),
                      pad_(torch.FloatTensor(D), max_atoms, dim=2),
                      len(Z))
    return res


def create_random_split(n, train_val=(.8, .1)):
    """Shuffle n indices and split them into train/val/test sets; `train_val` holds the
    train and validation fractions, the remainder goes to test."""
    idx = np.arange(n)
    np.random.shuffle(idx)
    train, val, test = np.split(idx, [int(train_val[0] * n), int(sum(train_val) * n)])
    return {'train': train, 'val': val, 'test': test}
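
# Minimal smoke test (sketch). The constants below are illustrative defaults,
# not values taken from the original training setup.
if __name__ == '__main__':
    Zs, Ds, sizes = create_dummy_batch(min_atoms=4, max_atoms=9, total_atoms=5, bs=8)
    print(Zs.shape, Ds.shape, sizes.shape)     # (8, 9), (8, 9, 9), (8,)
    print(mask_1d(sizes, 9).shape)             # (8, 9)
    print(mask_2d(sizes, 9).shape)             # (8, 9, 9)
    print(create_random_split(len(sizes)))     # arrays of train/val/test indices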