-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset.py
executable file
·70 lines (59 loc) · 2.29 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
import torch
import numpy as np
import time
import dgl
class QGTC_dataset(torch.nn.Module):
    """Graph dataset built from an edge list stored in a NumPy file.

    The constructor loads the edges found at ``path`` into a DGL graph and
    then attaches *synthetic* data to it: random normal node features,
    a constant label (``1``) for every node, and boolean train/val/test
    masks that each mark a prefix of the node range.

    Attributes set by the constructor:
        nodes: empty ``set`` (kept for compatibility; not populated here).
        num_nodes: node count reported by the graph after adding the edges.
        g: the DGL graph; ``g.ndata['feat']`` holds CPU features.
        edge_index: ``np.array([src_li, dst_li])``.
        num_features / num_classes: the ``dim`` / ``num_class`` arguments.
        x: random features moved to the GPU with ``.cuda()``.
        y: ``torch.LongTensor`` filled with ones, one entry per node.
        train_mask / val_mask / test_mask: ``torch.bool`` tensors.
    """

    def __init__(self, path, dim, num_class):
        """Load the graph at ``path`` and generate features, labels and masks.

        Args:
            path: file passed to ``np.load``; the loaded object must expose
                the keys ``'src_li'`` and ``'dst_li'``.
            dim: number of feature columns generated per node.
            num_class: stored as ``num_classes`` and forwarded to
                ``init_labels``.
        """
        super().__init__()

        self.nodes = set()
        self.num_nodes = 0
        self.g = dgl.DGLGraph()
        self.num_features = dim
        self.num_classes = num_class

        # Order matters: the node count produced by init_edges is needed by
        # both init_embedding and init_labels.
        self.init_edges(path)
        self.init_embedding(dim)
        self.init_labels(num_class)

        # Fractions of the node range covered by each mask. Every mask starts
        # at node 0, so the three masks overlap (train covers all nodes).
        train = 1
        val = 0.3
        test = 0.1
        self.train_mask = self._prefix_mask(self.num_nodes, train)
        self.val_mask = self._prefix_mask(self.num_nodes, val)
        self.test_mask = self._prefix_mask(self.num_nodes, test)

    @staticmethod
    def _prefix_mask(num_nodes, fraction):
        """Return a bool tensor whose first ``int(num_nodes * fraction)`` entries are True."""
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[:int(num_nodes * fraction)] = True
        return mask

    def init_edges(self, path):
        """Read the edge arrays from ``path`` and add them to ``self.g``.

        Sets ``self.edge_index`` and ``self.num_nodes`` and prints the elapsed
        loading time in milliseconds.

        Args:
            path: file passed to ``np.load``; the loaded object must expose
                the keys ``'src_li'`` and ``'dst_li'``.
        """
        start = time.perf_counter()

        graph_obj = np.load(path)
        try:
            src_li = graph_obj['src_li']
            dst_li = graph_obj['dst_li']
        finally:
            # An .npz archive keeps its file handle open until closed; other
            # objects returned by np.load (plain arrays) have no close().
            if hasattr(graph_obj, "close"):
                graph_obj.close()

        self.edge_index = np.array([src_li, dst_li])
        self.g.add_edges(src_li, dst_li)
        # The node count is taken from the graph itself after the edges are
        # added rather than from the file.
        self.num_nodes = self.g.number_of_nodes()

        dur = time.perf_counter() - start
        print("Loading (ms):\t{:.3f}".format(dur*1e3))

    def init_embedding(self, dim):
        """Generate random normal features of shape ``(num_nodes, dim)``.

        Two independent draws are made: ``self.x`` is moved to the GPU
        (this call requires CUDA), while ``self.g.ndata['feat']`` stays on
        the CPU. The two tensors therefore hold different values.
        """
        self.x = torch.randn(self.num_nodes, dim).cuda()
        self.g.ndata['feat'] = torch.randn(self.num_nodes, dim)

    def init_labels(self, num_class):
        """Assign the constant label ``1`` to every node.

        Args:
            num_class: accepted for interface symmetry; not used here.
        """
        self.y = torch.LongTensor([1] * self.num_nodes)

    def forward(*input, **kwargs):
        """No-op; the class only carries data and performs no computation."""
        pass