add appnp, gat, gin
DesmonDay committed Sep 23, 2024
1 parent 6dbb47c commit c5a9592
Showing 8 changed files with 341 additions and 6 deletions.
3 changes: 3 additions & 0 deletions examples/gnn_depoly/models/__init__.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .appnp import APPNP
from .gcn import GCN
from .gat import GAT
from .gin import GIN
from .graphsage import GraphSage
from .sgc import SGC
107 changes: 107 additions & 0 deletions examples/gnn_depoly/models/appnp.py
@@ -0,0 +1,107 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .base_conv import BaseConv


def gcn_norm(edge_index, num_nodes):
    # Symmetric GCN normalization: per-node deg^{-1/2}, read off the
    # destination column of the [num_edges, 2] edge_index.
    col = edge_index[:, 1]
degree = paddle.zeros(shape=[num_nodes], dtype="int64")
degree = paddle.scatter(
x=degree,
index=col,
updates=paddle.ones_like(
col, dtype="int64"),
overwrite=False)
norm = paddle.cast(degree, dtype=paddle.get_default_dtype())
norm = paddle.clip(norm, min=1.0)
norm = paddle.pow(norm, -0.5)
norm = paddle.reshape(norm, [-1, 1])
return norm
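
gcn_norm returns the per-node deg^{-1/2} factors applied on both sides of the propagation below; degrees are clipped to a minimum of 1 so isolated nodes do not divide by zero. A quick sanity check, with a toy graph assumed for illustration:

    import paddle

    # 3-node path graph 0-1-2 stored as directed edges in both directions;
    # in-degrees are [1, 2, 1], so the factors are deg**-0.5 per node.
    edge_index = paddle.to_tensor(
        [[0, 1], [1, 0], [1, 2], [2, 1]], dtype="int64")
    norm = gcn_norm(edge_index, num_nodes=3)
    # norm -> [[1.0], [0.7071], [1.0]], shape [3, 1]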


class APPNPConv(BaseConv):
def __init__(self, alpha=0.2, k_hop=10, self_loop=False):
super(APPNPConv, self).__init__()
self.alpha = alpha
self.k_hop = k_hop
self.self_loop = self_loop

def forward(self, edge_index, num_nodes, feature, norm=None):
if self.self_loop:
index = paddle.arange(start=0, end=num_nodes, dtype="int64")
self_loop_edges = paddle.transpose(paddle.stack((index, index)), [1, 0])

mask = edge_index[:, 0] != edge_index[:, 1]
mask_index = paddle.masked_select(paddle.arange(end=edge_index.shape[0]), mask)
edges = paddle.gather(edge_index, mask_index) # remove self loop

edge_index = paddle.concat((self_loop_edges, edges), axis=0)
if norm is None:
norm = gcn_norm(edge_index, num_nodes)

h0 = feature

for _ in range(self.k_hop):
feature = feature * norm
feature = self.send_recv(edge_index, feature, "sum")
feature = feature * norm
feature = self.alpha * h0 + (1 - self.alpha) * feature

return feature
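
The loop above is the APPNP power iteration: with the normalized adjacency A_hat = D^{-1/2} A D^{-1/2}, each hop computes

    H <- alpha * H_0 + (1 - alpha) * A_hat @ H

which approximates personalized-PageRank propagation of the initial features H_0 over k_hop steps.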


class APPNP(nn.Layer):
"""Implement of APPNP"""

def __init__(self,
input_size,
num_class,
num_layers=1,
hidden_size=64,
dropout=0.5,
k_hop=10,
alpha=0.1,
**kwargs):
super(APPNP, self).__init__()
self.num_class = num_class
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout = dropout
self.alpha = alpha
self.k_hop = k_hop

self.mlps = nn.LayerList()
self.mlps.append(nn.Linear(input_size, self.hidden_size))
self.drop_fn = nn.Dropout(self.dropout)
for _ in range(self.num_layers - 1):
self.mlps.append(nn.Linear(self.hidden_size, self.hidden_size))

self.output = nn.Linear(self.hidden_size, num_class)
self.appnp = APPNPConv(alpha=self.alpha, k_hop=self.k_hop)

def forward(self, edge_index, num_nodes, feature):
for m in self.mlps:
feature = self.drop_fn(feature)
feature = m(feature)
feature = F.relu(feature)
feature = self.drop_fn(feature)
feature = self.output(feature)
feature = self.appnp(edge_index, num_nodes, feature)
return feature
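
A minimal forward-pass sketch for the model above (the toy graph and shapes are assumptions for illustration, not part of the example suite):

    import paddle

    model = APPNP(input_size=8, num_class=2, num_layers=2, k_hop=10, alpha=0.1)
    edge_index = paddle.to_tensor(
        [[0, 1], [1, 0], [1, 2], [2, 1]], dtype="int64")
    feature = paddle.randn([3, 8])
    logits = model(edge_index, num_nodes=3, feature=feature)  # shape [3, 2]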
2 changes: 1 addition & 1 deletion examples/gnn_depoly/models/base_conv.py
@@ -31,7 +31,7 @@ def __init__(self):
def send_recv(self, edge_index, feature, pool_type="sum"):
src, dst = edge_index[:, 0], edge_index[:, 1]
return paddle.geometric.send_u_recv(
-            feature, src, dst, pool_type=pool_type)
+            feature, src, dst, reduce_op=pool_type)

def send(
self,
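
The one-line change above tracks a Paddle keyword rename: paddle.geometric.send_u_recv takes reduce_op where older releases used pool_type. A toy sketch of what the helper computes (tensors are illustrative; assumes a recent Paddle release):

    import paddle

    # Gather rows of x at src and sum-reduce them into the rows named by dst.
    x = paddle.to_tensor([[1.0], [2.0], [4.0]])
    src = paddle.to_tensor([0, 1, 2], dtype="int64")
    dst = paddle.to_tensor([1, 2, 1], dtype="int64")
    out = paddle.geometric.send_u_recv(x, src, dst, reduce_op="sum")
    # out -> [[0.], [5.], [2.]]: row 1 receives x[0] + x[2], row 2 receives x[1].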
1 change: 1 addition & 0 deletions examples/gnn_depoly/models/gat.py
@@ -15,6 +15,7 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
+import pgl

from .base_conv import BaseConv

102 changes: 102 additions & 0 deletions examples/gnn_depoly/models/gin.py
@@ -0,0 +1,102 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from pgl.math import segment_pool

from .base_conv import BaseConv


class GINConv(BaseConv):
def __init__(
self, input_size, output_size, activation=None, init_eps=0.0, train_eps=False
):
super(GINConv, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.linear1 = nn.Linear(input_size, output_size, bias_attr=True)
self.linear2 = nn.Linear(output_size, output_size, bias_attr=True)
self.layer_norm = nn.LayerNorm(output_size)
if train_eps:
self.epsilon = self.create_parameter(
shape=[1, 1],
dtype="float32",
default_initializer=nn.initializer.Constant(value=init_eps),
)
else:
self.epsilon = init_eps

if isinstance(activation, str):
activation = getattr(F, activation)
self.activation = activation

def forward(self, edge_index, feature):
neigh_feature = self.send_recv(edge_index, feature, "sum")
output = neigh_feature + feature * (self.epsilon + 1.0)

output = self.linear1(output)
output = self.layer_norm(output)

if self.activation is not None:
output = self.activation(output)

output = self.linear2(output)

return output
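
GINConv above implements the standard GIN aggregation followed by a two-layer MLP with LayerNorm between the layers:

    h_v <- MLP((1 + epsilon) * h_v + sum_{u in N(v)} h_u)

where epsilon is the fixed init_eps, or a learnable parameter when train_eps=True.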


class GIN(nn.Layer):
def __init__(self, input_size, num_class, num_layers, hidden_size, pool_type="sum", dropout_prob=0.0):
super(GIN, self).__init__()
self.input_size = input_size
self.num_class = num_class
self.num_layers = num_layers
self.hidden_size = hidden_size
self.pool_type = pool_type
self.dropout_prob = dropout_prob

self.gin_convs = nn.LayerList()
self.norms = nn.LayerList()
self.linears = nn.LayerList()
self.linears.append(nn.Linear(self.input_size, self.num_class))

for i in range(self.num_layers):
if i == 0:
input_size = self.input_size
else:
input_size = self.hidden_size
gin = GINConv(input_size, self.hidden_size, "relu")
self.gin_convs.append(gin)
ln = paddle.nn.LayerNorm(self.hidden_size)
self.norms.append(ln)
self.linears.append(nn.Linear(self.hidden_size, self.num_class))
        self.relu = nn.ReLU()

    def forward(self, edge_index, feature):
feature_list = [feature]
for i in range(self.num_layers):
h = self.gin_convs[i](edge_index, feature_list[i])
h = self.norms[i](h)
h = self.relu(h)
feature_list.append(h)

output = 0
        for i, h in enumerate(feature_list):
            # Respect train/eval mode so dropout is disabled at inference time.
            h = F.dropout(h, p=self.dropout_prob, training=self.training)
            output += self.linears[i](h)

return output
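
Note that this GIN variant emits node-level logits: pool_type is stored but never used, and the segment_pool import is unused in the forward shown here, so no graph-level readout happens. A minimal sketch with assumed toy shapes:

    import paddle

    model = GIN(input_size=8, num_class=2, num_layers=2, hidden_size=16)
    edge_index = paddle.to_tensor(
        [[0, 1], [1, 0], [1, 2], [2, 1]], dtype="int64")
    feature = paddle.randn([3, 8])
    logits = model(edge_index, feature)  # shape [3, 2]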
92 changes: 92 additions & 0 deletions examples/gnn_depoly/models/sgc.py
@@ -0,0 +1,92 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .base_conv import BaseConv


def gcn_norm(edge_index, num_nodes):
    # Symmetric GCN normalization: per-node deg^{-1/2}, read off the
    # destination column of the [num_edges, 2] edge_index.
    col = edge_index[:, 1]
degree = paddle.zeros(shape=[num_nodes], dtype="int64")
degree = paddle.scatter(
x=degree,
index=col,
updates=paddle.ones_like(
col, dtype="int64"),
overwrite=False)
norm = paddle.cast(degree, dtype=paddle.get_default_dtype())
norm = paddle.clip(norm, min=1.0)
norm = paddle.pow(norm, -0.5)
norm = paddle.reshape(norm, [-1, 1])
return norm


class SGCConv(BaseConv):
def __init__(self, input_size, output_size, k_hop=2, cached=True, activation=None, bias=False):
super(SGCConv, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.k_hop = k_hop
self.linear = nn.Linear(input_size, output_size, bias_attr=False)
if bias:
self.bias = self.create_parameter(shape=[output_size], is_bias=True)

self.cached = cached
self.cached_output = None
if isinstance(activation, str):
activation = getattr(F, activation)
self.activation = activation

def forward(self, edge_index, num_nodes, feature):
if self.cached:
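            # Behavioral caveat: with cached=True, the propagation result from
            # the first forward pass is reused verbatim; later calls ignore any
            # new edge_index / feature inputs.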
if self.cached_output is None:
norm = gcn_norm(edge_index, num_nodes)
for hop in range(self.k_hop):
feature = feature * norm
feature = self.send_recv(edge_index, feature, "sum")
feature = feature * norm
self.cached_output = feature
else:
feature = self.cached_output
else:
norm = gcn_norm(edge_index, num_nodes)
for hop in range(self.k_hop):
feature = feature * norm
feature = self.send_recv(edge_index, feature, "sum")
feature = feature * norm

output = self.linear(feature)
if hasattr(self, "bias"):
output = output + self.bias

if self.activation is not None:
output = self.activation(output)
return output


class SGC(nn.Layer):
def __init__(self, input_size, num_class, num_layers=1, **kwargs):
super(SGC, self).__init__()
self.num_class = num_class
self.num_layers = num_layers
self.sgc_layer = SGCConv(
input_size=input_size, output_size=num_class, k_hop=num_layers)

def forward(self, edge_index, num_nodes, feature):
feature = self.sgc_layer(edge_index, num_nodes, feature)
return feature
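
SGC collapses the k_hop propagation steps and a single linear layer, i.e. logits = (A_hat ** K) X W, with A_hat the normalized adjacency computed by gcn_norm above. A minimal smoke test with assumed toy shapes:

    import paddle

    model = SGC(input_size=8, num_class=2, num_layers=2)
    edge_index = paddle.to_tensor(
        [[0, 1], [1, 0], [1, 2], [2, 1]], dtype="int64")
    feature = paddle.randn([3, 8])
    logits = model(edge_index, num_nodes=3, feature=feature)  # shape [3, 2]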
@@ -21,7 +21,7 @@
import pgl

sys.path.insert(0, os.path.abspath(".."))
-from models import GCN, GAT, GraphSage
+from models import GCN, GAT, GraphSage, GIN, SGC, APPNP


def save_static_model(args):
@@ -35,6 +35,21 @@ def save_static_model(args):
attn_drop=0.6,
num_heads=8,
hidden_size=8)
elif args.model == "APPNP":
model = APPNP(input_size=dataset.graph.node_feat["words"].shape[1],
num_class=dataset.num_classes,
num_layers=2,
alpha=0.1,
k_hop=10)
elif args.model == "SGC":
model = SGC(input_size=dataset.graph.node_feat["words"].shape[1],
num_class=dataset.num_classes,
num_layers=2)
elif args.model == "GIN":
model = GIN(input_size=dataset.graph.node_feat["words"].shape[1],
num_class=dataset.num_classes,
num_layers=2,
hidden_size=16)
elif args.model == "GCN":
model = GCN(input_size=dataset.graph.node_feat["words"].shape[1],
num_class=dataset.num_classes,
@@ -54,7 +69,7 @@ def save_static_model(args):
model.eval()

# Convert to static graph with specific input description
if args.model == "GraphSage":
if args.model == "GraphSage" or args.model == "GIN":
model = paddle.jit.to_static(
model,
input_spec=[
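
GIN shares the GraphSage branch, presumably because its forward takes only (edge_index, feature) with no num_nodes argument, while the other models need a three-input spec. The actual input_spec is collapsed in this view; a hypothetical spec for the two-input branch might look like this (names and shapes are illustrative, not the repository's code):

    import paddle
    from paddle.static import InputSpec

    # Variable-length edge list and a node feature matrix with dynamic dims.
    input_spec = [
        InputSpec(shape=[None, 2], dtype="int64", name="edge_index"),
        InputSpec(shape=[None, None], dtype="float32", name="feature"),
    ]
    static_model = paddle.jit.to_static(model, input_spec=input_spec)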