-
Notifications
You must be signed in to change notification settings - Fork 7
/
EdgeGCN.py
90 lines (71 loc) · 5.12 KB
/
EdgeGCN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
'''
EdgeGCN used for SGGpoint (Chaoyi Zhang)
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.nn import GCNConv
##################################################
# #
# #
# Core Network: EdgeGCN #
# #
# #
##################################################
class EdgeGCN(torch.nn.Module):
    """One EdgeGCN layer: a node stream and an edge stream that can gate each other.

    The node stream is two GCNConv layers (N -> N//2 -> N channels). The edge
    stream is two kernel-size-1 Conv1d layers acting as a per-edge MLP
    (E -> E//2 -> E channels). Optionally, each stream produces sigmoid gates
    that multiply the hidden features of the other stream.

    Args:
        num_node_in_embeddings: channel count of the input node features.
        num_edge_in_embeddings: channel count of the input edge features.
        AttnEdgeFlag: if truthy, gates derived from edge features are applied to
            the hidden node features (switch kept for ablation studies).
        AttnNodeFlag: if truthy, gates derived from the endpoint node features
            are applied to the hidden edge features (switch kept for ablation
            studies).
    """

    def __init__(self, num_node_in_embeddings, num_edge_in_embeddings, AttnEdgeFlag, AttnNodeFlag):
        super().__init__()
        node_mid = num_node_in_embeddings // 2
        edge_mid = num_edge_in_embeddings // 2

        # Node evolution stream: N -> N//2 -> N channels.
        self.node_GConv1 = GCNConv(num_node_in_embeddings, node_mid, add_self_loops=True)
        self.node_GConv2 = GCNConv(node_mid, num_node_in_embeddings, add_self_loops=True)
        # Edge evolution stream: a kernel-size-1 Conv1d over the edge axis is a per-edge MLP.
        self.edge_MLP1 = nn.Sequential(nn.Conv1d(num_edge_in_embeddings, edge_mid, 1), nn.ReLU())
        self.edge_MLP2 = nn.Sequential(nn.Conv1d(edge_mid, num_edge_in_embeddings, 1), nn.ReLU())

        self.AttnEdgeFlag = AttnEdgeFlag  # boolean (for ablation studies)
        self.AttnNodeFlag = AttnNodeFlag  # boolean (for ablation studies)

        # Multi-dimensional (per-channel) node/edge attention mappings.
        # Edge features -> indicators over the node stream's hidden (node_mid) channels.
        self.edge_attentionND = nn.Linear(num_edge_in_embeddings, node_mid) if self.AttnEdgeFlag else None
        # Node features -> per-node indicators with edge_mid channels ...
        self.node_attentionND = nn.Linear(num_node_in_embeddings, edge_mid) if self.AttnNodeFlag else None
        # ... the subject/object pair of which (2 * edge_mid channels) is reduced to
        # gates over the edge stream's hidden channels. The input width must be
        # 2 * edge_mid rather than num_edge_in_embeddings: the two differ when the
        # edge width is odd. For even widths the layer shape is unchanged.
        self.node_indicator_reduction = nn.Linear(2 * edge_mid, edge_mid) if self.AttnNodeFlag else None

    def concate_NodeIndicator_for_edges(self, node_indicator, batchwise_edge_index):
        """Gather and concatenate the indicators of each edge's two endpoint nodes.

        Args:
            node_indicator: per-node indicators of shape (1, num_nodes, C); the
                leading size-1 dim is squeezed away.
            batchwise_edge_index: index tensor whose row 0 holds each edge's
                subject node and row 1 its object node, i.e. (2, num_edges).

        Returns:
            Tensor of shape (num_edges, 2 * C): the subject node's indicator
            followed by the object node's indicator.
        """
        node_indicator = node_indicator.squeeze(0)  # (num_nodes, C)
        subject_indicator = node_indicator[batchwise_edge_index[0]]  # (num_edges, C)
        object_indicator = node_indicator[batchwise_edge_index[1]]  # (num_edges, C)
        return torch.cat((subject_indicator, object_indicator), dim=1)  # (num_edges, 2 * C)

    def forward(self, node_data, edge_feats):
        """Run one EdgeGCN update over nodes and edges.

        Args:
            node_data: graph object exposing ``x`` with shape
                (num_nodes, num_node_in_embeddings) and ``edge_index`` with shape
                (2, num_edges). Presumably a torch_geometric ``Data``; confirm
                against callers.
            edge_feats: edge features of shape (1, num_edges, num_edge_in_embeddings).

        Returns:
            dict with
              'node_feats': (1, num_nodes, num_node_in_embeddings)
              'edge_feats': (1, num_edges, num_edge_in_embeddings)
        """
        node_feats, edge_index = node_data.x, node_data.edge_index
        num_nodes = node_feats.size(0)
        # Index rows directly so the indices stay 1-D even for a single edge.
        # (The previous `.t()[:, 0].squeeze(0)` collapsed them to a 0-dim tensor.)
        row, col = edge_index[0], edge_index[1]

        #### Deriving Edge Attention: per-node, per-channel gates from incident edges
        if self.AttnEdgeFlag:
            edge_indicator = self.edge_attentionND(edge_feats.squeeze(0)).unsqueeze(0).permute(0, 2, 1)  # (1, node_mid, num_edges)
            # Mean indicator over the edges whose row-0 / row-1 endpoint is each node.
            raw_out_row = scatter(edge_indicator, row, dim=2, reduce='mean', dim_size=num_nodes)  # (1, node_mid, num_nodes)
            raw_out_col = scatter(edge_indicator, col, dim=2, reduce='mean', dim_size=num_nodes)  # (1, node_mid, num_nodes)
            agg_edge_indicator_logits = raw_out_row * raw_out_col  # (1, node_mid, num_nodes)
            agg_edge_indicator = torch.sigmoid(agg_edge_indicator_logits).permute(0, 2, 1).squeeze(0)  # (num_nodes, node_mid)
        else:
            agg_edge_indicator = 1  # neutral gate

        #### Node Evolution Stream (NodeGCN)
        node_feats = F.relu(self.node_GConv1(node_feats, edge_index)) * agg_edge_indicator  # applying EdgeAttn on Nodes
        node_feats = F.dropout(node_feats, training=self.training)
        node_feats = F.relu(self.node_GConv2(node_feats, edge_index))
        node_feats = node_feats.unsqueeze(0)  # (1, num_nodes, num_node_in_embeddings)

        #### Deriving Node Attention: per-edge, per-channel gates from endpoint nodes
        if self.AttnNodeFlag:
            # Keep the leading batch dim: the helper squeezes dim 0, which must not
            # eat the node axis when num_nodes == 1.
            node_indicator = F.relu(self.node_attentionND(node_feats.squeeze(0)).unsqueeze(0))  # (1, num_nodes, edge_mid)
            agg_node_indicator = self.concate_NodeIndicator_for_edges(node_indicator, edge_index)  # (num_edges, 2 * edge_mid)
            agg_node_indicator = self.node_indicator_reduction(agg_node_indicator).unsqueeze(0).permute(0, 2, 1)  # (1, edge_mid, num_edges)
            agg_node_indicator = torch.sigmoid(agg_node_indicator)  # (1, edge_mid, num_edges)
        else:
            agg_node_indicator = 1  # neutral gate

        #### Edge Evolution Stream (EdgeMLP)
        edge_feats = edge_feats.permute(0, 2, 1)  # (1, num_edge_in_embeddings, num_edges)
        edge_feats = self.edge_MLP1(edge_feats)  # (1, edge_mid, num_edges)
        edge_feats = F.dropout(edge_feats, training=self.training) * agg_node_indicator  # applying NodeAttn on Edges
        edge_feats = self.edge_MLP2(edge_feats).permute(0, 2, 1)  # (1, num_edges, num_edge_in_embeddings)

        return {'node_feats': node_feats, 'edge_feats': edge_feats}