-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
341 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
|
||
from .base_conv import BaseConv | ||
|
||
|
||
def gcn_norm(edge_index, num_nodes): | ||
_, col = edge_index[:, 0], edge_index[:, 1] | ||
degree = paddle.zeros(shape=[num_nodes], dtype="int64") | ||
degree = paddle.scatter( | ||
x=degree, | ||
index=col, | ||
updates=paddle.ones_like( | ||
col, dtype="int64"), | ||
overwrite=False) | ||
norm = paddle.cast(degree, dtype=paddle.get_default_dtype()) | ||
norm = paddle.clip(norm, min=1.0) | ||
norm = paddle.pow(norm, -0.5) | ||
norm = paddle.reshape(norm, [-1, 1]) | ||
return norm | ||
|
||
|
||
class APPNPConv(BaseConv): | ||
def __init__(self, alpha=0.2, k_hop=10, self_loop=False): | ||
super(APPNPConv, self).__init__() | ||
self.alpha = alpha | ||
self.k_hop = k_hop | ||
self.self_loop = self_loop | ||
|
||
def forward(self, edge_index, num_nodes, feature, norm=None): | ||
if self.self_loop: | ||
index = paddle.arange(start=0, end=num_nodes, dtype="int64") | ||
self_loop_edges = paddle.transpose(paddle.stack((index, index)), [1, 0]) | ||
|
||
mask = edge_index[:, 0] != edge_index[:, 1] | ||
mask_index = paddle.masked_select(paddle.arange(end=edge_index.shape[0]), mask) | ||
edges = paddle.gather(edge_index, mask_index) # remove self loop | ||
|
||
edge_index = paddle.concat((self_loop_edges, edges), axis=0) | ||
if norm is None: | ||
norm = gcn_norm(edge_index, num_nodes) | ||
|
||
h0 = feature | ||
|
||
for _ in range(self.k_hop): | ||
feature = feature * norm | ||
feature = self.send_recv(edge_index, feature, "sum") | ||
feature = feature * norm | ||
feature = self.alpha * h0 + (1 - self.alpha) * feature | ||
|
||
return feature | ||
|
||
|
||
class APPNP(nn.Layer): | ||
"""Implement of APPNP""" | ||
|
||
def __init__(self, | ||
input_size, | ||
num_class, | ||
num_layers=1, | ||
hidden_size=64, | ||
dropout=0.5, | ||
k_hop=10, | ||
alpha=0.1, | ||
**kwargs): | ||
super(APPNP, self).__init__() | ||
self.num_class = num_class | ||
self.num_layers = num_layers | ||
self.hidden_size = hidden_size | ||
self.dropout = dropout | ||
self.alpha = alpha | ||
self.k_hop = k_hop | ||
|
||
self.mlps = nn.LayerList() | ||
self.mlps.append(nn.Linear(input_size, self.hidden_size)) | ||
self.drop_fn = nn.Dropout(self.dropout) | ||
for _ in range(self.num_layers - 1): | ||
self.mlps.append(nn.Linear(self.hidden_size, self.hidden_size)) | ||
|
||
self.output = nn.Linear(self.hidden_size, num_class) | ||
self.appnp = APPNPConv(alpha=self.alpha, k_hop=self.k_hop) | ||
|
||
def forward(self, edge_index, num_nodes, feature): | ||
for m in self.mlps: | ||
feature = self.drop_fn(feature) | ||
feature = m(feature) | ||
feature = F.relu(feature) | ||
feature = self.drop_fn(feature) | ||
feature = self.output(feature) | ||
feature = self.appnp(edge_index, num_nodes, feature) | ||
return feature |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
from pgl.math import segment_pool | ||
|
||
from .base_conv import BaseConv | ||
|
||
|
||
class GINConv(BaseConv): | ||
def __init__( | ||
self, input_size, output_size, activation=None, init_eps=0.0, train_eps=False | ||
): | ||
super(GINConv, self).__init__() | ||
self.input_size = input_size | ||
self.output_size = output_size | ||
self.linear1 = nn.Linear(input_size, output_size, bias_attr=True) | ||
self.linear2 = nn.Linear(output_size, output_size, bias_attr=True) | ||
self.layer_norm = nn.LayerNorm(output_size) | ||
if train_eps: | ||
self.epsilon = self.create_parameter( | ||
shape=[1, 1], | ||
dtype="float32", | ||
default_initializer=nn.initializer.Constant(value=init_eps), | ||
) | ||
else: | ||
self.epsilon = init_eps | ||
|
||
if isinstance(activation, str): | ||
activation = getattr(F, activation) | ||
self.activation = activation | ||
|
||
def forward(self, edge_index, feature): | ||
neigh_feature = self.send_recv(edge_index, feature, "sum") | ||
output = neigh_feature + feature * (self.epsilon + 1.0) | ||
|
||
output = self.linear1(output) | ||
output = self.layer_norm(output) | ||
|
||
if self.activation is not None: | ||
output = self.activation(output) | ||
|
||
output = self.linear2(output) | ||
|
||
return output | ||
|
||
|
||
class GIN(nn.Layer): | ||
def __init__(self, input_size, num_class, num_layers, hidden_size, pool_type="sum", dropout_prob=0.0): | ||
super(GIN, self).__init__() | ||
self.input_size = input_size | ||
self.num_class = num_class | ||
self.num_layers = num_layers | ||
self.hidden_size = hidden_size | ||
self.pool_type = pool_type | ||
self.dropout_prob = dropout_prob | ||
|
||
self.gin_convs = nn.LayerList() | ||
self.norms = nn.LayerList() | ||
self.linears = nn.LayerList() | ||
self.linears.append(nn.Linear(self.input_size, self.num_class)) | ||
|
||
for i in range(self.num_layers): | ||
if i == 0: | ||
input_size = self.input_size | ||
else: | ||
input_size = self.hidden_size | ||
gin = GINConv(input_size, self.hidden_size, "relu") | ||
self.gin_convs.append(gin) | ||
ln = paddle.nn.LayerNorm(self.hidden_size) | ||
self.norms.append(ln) | ||
self.linears.append(nn.Linear(self.hidden_size, self.num_class)) | ||
self.relu = nn.ReLU() | ||
|
||
|
||
def forward(self, edge_index, feature): | ||
feature_list = [feature] | ||
for i in range(self.num_layers): | ||
h = self.gin_convs[i](edge_index, feature_list[i]) | ||
h = self.norms[i](h) | ||
h = self.relu(h) | ||
feature_list.append(h) | ||
|
||
output = 0 | ||
for i, h in enumerate(feature_list): | ||
h = F.dropout(h, p=self.dropout_prob) | ||
output += self.linears[i](h) | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
|
||
from .base_conv import BaseConv | ||
|
||
|
||
def gcn_norm(edge_index, num_nodes): | ||
_, col = edge_index[:, 0], edge_index[:, 1] | ||
degree = paddle.zeros(shape=[num_nodes], dtype="int64") | ||
degree = paddle.scatter( | ||
x=degree, | ||
index=col, | ||
updates=paddle.ones_like( | ||
col, dtype="int64"), | ||
overwrite=False) | ||
norm = paddle.cast(degree, dtype=paddle.get_default_dtype()) | ||
norm = paddle.clip(norm, min=1.0) | ||
norm = paddle.pow(norm, -0.5) | ||
norm = paddle.reshape(norm, [-1, 1]) | ||
return norm | ||
|
||
|
||
class SGCConv(BaseConv): | ||
def __init__(self, input_size, output_size, k_hop=2, cached=True, activation=None, bias=False): | ||
super(SGCConv, self).__init__() | ||
self.input_size = input_size | ||
self.output_size = output_size | ||
self.k_hop = k_hop | ||
self.linear = nn.Linear(input_size, output_size, bias_attr=False) | ||
if bias: | ||
self.bias = self.create_parameter(shape=[output_size], is_bias=True) | ||
|
||
self.cached = cached | ||
self.cached_output = None | ||
if isinstance(activation, str): | ||
activation = getattr(F, activation) | ||
self.activation = activation | ||
|
||
def forward(self, edge_index, num_nodes, feature): | ||
if self.cached: | ||
if self.cached_output is None: | ||
norm = gcn_norm(edge_index, num_nodes) | ||
for hop in range(self.k_hop): | ||
feature = feature * norm | ||
feature = self.send_recv(edge_index, feature, "sum") | ||
feature = feature * norm | ||
self.cached_output = feature | ||
else: | ||
feature = self.cached_output | ||
else: | ||
norm = gcn_norm(edge_index, num_nodes) | ||
for hop in range(self.k_hop): | ||
feature = feature * norm | ||
feature = self.send_recv(edge_index, feature, "sum") | ||
feature = feature * norm | ||
|
||
output = self.linear(feature) | ||
if hasattr(self, "bias"): | ||
output = output + self.bias | ||
|
||
if self.activation is not None: | ||
output = self.activation(output) | ||
return output | ||
|
||
|
||
class SGC(nn.Layer): | ||
def __init__(self, input_size, num_class, num_layers=1, **kwargs): | ||
super(SGC, self).__init__() | ||
self.num_class = num_class | ||
self.num_layers = num_layers | ||
self.sgc_layer = SGCConv( | ||
input_size=input_size, output_size=num_class, k_hop=num_layers) | ||
|
||
def forward(self, edge_index, num_nodes, feature): | ||
feature = self.sgc_layer(edge_index, num_nodes, feature) | ||
return feature |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.