From a50e399e6df0ca6628ba7a98b8958c460ad69dd3 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Mon, 23 Sep 2024 20:41:16 +0800
Subject: [PATCH] add appnp, gin, sgc

---
 examples/gnn_depoly/models/__init__.py  |   3 +
 examples/gnn_depoly/models/appnp.py     | 107 ++++++++++++++++++
 examples/gnn_depoly/models/base_conv.py |   2 +-
 examples/gnn_depoly/models/gin.py       | 101 +++++++++++++++++
 examples/gnn_depoly/models/sgc.py       |  92 +++++++++++++++
 .../export_model.py                     |  19 +++-
 .../train.py                            |  21 +++-
 7 files changed, 339 insertions(+), 6 deletions(-)
 create mode 100644 examples/gnn_depoly/models/appnp.py
 create mode 100644 examples/gnn_depoly/models/gin.py
 create mode 100644 examples/gnn_depoly/models/sgc.py

diff --git a/examples/gnn_depoly/models/__init__.py b/examples/gnn_depoly/models/__init__.py
index 6bed5aa0..f33508d7 100644
--- a/examples/gnn_depoly/models/__init__.py
+++ b/examples/gnn_depoly/models/__init__.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .appnp import APPNP
 from .gcn import GCN
 from .gat import GAT
+from .gin import GIN
 from .graphsage import GraphSage
+from .sgc import SGC

diff --git a/examples/gnn_depoly/models/appnp.py b/examples/gnn_depoly/models/appnp.py
new file mode 100644
index 00000000..3e15a494
--- /dev/null
+++ b/examples/gnn_depoly/models/appnp.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .base_conv import BaseConv
+
+
+def gcn_norm(edge_index, num_nodes):
+    _, col = edge_index[:, 0], edge_index[:, 1]
+    degree = paddle.zeros(shape=[num_nodes], dtype="int64")
+    degree = paddle.scatter(
+        x=degree,
+        index=col,
+        updates=paddle.ones_like(
+            col, dtype="int64"),
+        overwrite=False)
+    norm = paddle.cast(degree, dtype=paddle.get_default_dtype())
+    norm = paddle.clip(norm, min=1.0)
+    norm = paddle.pow(norm, -0.5)
+    norm = paddle.reshape(norm, [-1, 1])
+    return norm
+
+
+class APPNPConv(BaseConv):
+    def __init__(self, alpha=0.2, k_hop=10, self_loop=False):
+        super(APPNPConv, self).__init__()
+        self.alpha = alpha
+        self.k_hop = k_hop
+        self.self_loop = self_loop
+
+    def forward(self, edge_index, num_nodes, feature, norm=None):
+        if self.self_loop:
+            index = paddle.arange(start=0, end=num_nodes, dtype="int64")
+            self_loop_edges = paddle.transpose(paddle.stack((index, index)), [1, 0])
+
+            mask = edge_index[:, 0] != edge_index[:, 1]
+            mask_index = paddle.masked_select(paddle.arange(end=edge_index.shape[0]), mask)
+            edges = paddle.gather(edge_index, mask_index)  # remove self loop
+
+            edge_index = paddle.concat((self_loop_edges, edges), axis=0)
+        if norm is None:
+            norm = gcn_norm(edge_index, num_nodes)
+
+        h0 = feature
+
+        for _ in range(self.k_hop):
+            feature = feature * norm
+            feature = self.send_recv(edge_index, feature, "sum")
+            feature = feature * norm
+            feature = self.alpha * h0 + (1 - self.alpha) * feature
+
+        return feature
+
+
+class APPNP(nn.Layer):
+    """Implementation of APPNP."""
+
+    def __init__(self,
+                 input_size,
+                 num_class,
+                 num_layers=1,
+                 hidden_size=64,
+                 dropout=0.5,
+                 k_hop=10,
+                 alpha=0.1,
+                 **kwargs):
+        super(APPNP, self).__init__()
+        self.num_class = num_class
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.alpha = alpha
+        self.k_hop = k_hop
+
+        self.mlps = nn.LayerList()
+        self.mlps.append(nn.Linear(input_size, self.hidden_size))
+        self.drop_fn = nn.Dropout(self.dropout)
+        for _ in range(self.num_layers - 1):
+            self.mlps.append(nn.Linear(self.hidden_size, self.hidden_size))
+
+        self.output = nn.Linear(self.hidden_size, num_class)
+        self.appnp = APPNPConv(alpha=self.alpha, k_hop=self.k_hop)
+
+    def forward(self, edge_index, num_nodes, feature):
+        for m in self.mlps:
+            feature = self.drop_fn(feature)
+            feature = m(feature)
+            feature = F.relu(feature)
+        feature = self.drop_fn(feature)
+        feature = self.output(feature)
+        feature = self.appnp(edge_index, num_nodes, feature)
+        return feature
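Note on appnp.py: APPNPConv implements the personalized-PageRank-style
propagation H(t+1) = (1 - alpha) * A_hat * H(t) + alpha * H(0), where A_hat
is the symmetrically normalized adjacency built from gcn_norm and H(0) is the
MLP output, so prediction and propagation stay decoupled. A minimal smoke
test, as a sketch only (assumes it is run from examples/gnn_depoly/ so the
models package resolves; the toy graph and sizes are invented):

    import paddle
    from models import APPNP

    # two nodes connected in both directions, edges in [num_edges, 2] layout
    edges = paddle.to_tensor([[0, 1], [1, 0]], dtype="int64")
    feat = paddle.randn([2, 16])  # 2 nodes, 16 input features
    model = APPNP(input_size=16, num_class=3, alpha=0.1, k_hop=10)
    model.eval()
    logits = model(edges, 2, feat)  # -> [2, 3] class scores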
diff --git a/examples/gnn_depoly/models/base_conv.py b/examples/gnn_depoly/models/base_conv.py
index adde295f..3f5069ff 100644
--- a/examples/gnn_depoly/models/base_conv.py
+++ b/examples/gnn_depoly/models/base_conv.py
@@ -31,7 +31,7 @@ def __init__(self):
     def send_recv(self, edge_index, feature, pool_type="sum"):
         src, dst = edge_index[:, 0], edge_index[:, 1]
         return paddle.geometric.send_u_recv(
-            feature, src, dst, pool_type=pool_type)
+            feature, src, dst, reduce_op=pool_type)
 
     def send(
         self,

diff --git a/examples/gnn_depoly/models/gin.py b/examples/gnn_depoly/models/gin.py
new file mode 100644
index 00000000..46945b8a
--- /dev/null
+++ b/examples/gnn_depoly/models/gin.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .base_conv import BaseConv
+
+
+class GINConv(BaseConv):
+    def __init__(
+        self, input_size, output_size, activation=None, init_eps=0.0, train_eps=False
+    ):
+        super(GINConv, self).__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.linear1 = nn.Linear(input_size, output_size, bias_attr=True)
+        self.linear2 = nn.Linear(output_size, output_size, bias_attr=True)
+        self.layer_norm = nn.LayerNorm(output_size)
+        if train_eps:
+            self.epsilon = self.create_parameter(
+                shape=[1, 1],
+                dtype="float32",
+                default_initializer=nn.initializer.Constant(value=init_eps),
+            )
+        else:
+            self.epsilon = init_eps
+
+        if isinstance(activation, str):
+            activation = getattr(F, activation)
+        self.activation = activation
+
+    def forward(self, edge_index, feature):
+        neigh_feature = self.send_recv(edge_index, feature, "sum")
+        output = neigh_feature + feature * (self.epsilon + 1.0)
+
+        output = self.linear1(output)
+        output = self.layer_norm(output)
+
+        if self.activation is not None:
+            output = self.activation(output)
+
+        output = self.linear2(output)
+
+        return output
+
+
+class GIN(nn.Layer):
+    def __init__(self, input_size, num_class, num_layers, hidden_size, pool_type="sum", dropout_prob=0.0):
+        super(GIN, self).__init__()
+        self.input_size = input_size
+        self.num_class = num_class
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.pool_type = pool_type
+        self.dropout_prob = dropout_prob
+
+        self.gin_convs = nn.LayerList()
+        self.norms = nn.LayerList()
+        self.linears = nn.LayerList()
+        self.linears.append(nn.Linear(self.input_size, self.num_class))
+
+        for i in range(self.num_layers):
+            if i == 0:
+                input_size = self.input_size
+            else:
+                input_size = self.hidden_size
+            gin = GINConv(input_size, self.hidden_size, "relu")
+            self.gin_convs.append(gin)
+            ln = paddle.nn.LayerNorm(self.hidden_size)
+            self.norms.append(ln)
+            self.linears.append(nn.Linear(self.hidden_size, self.num_class))
+        self.relu = nn.ReLU()
+
+
+    def forward(self, edge_index, feature):
+        feature_list = [feature]
+        for i in range(self.num_layers):
+            h = self.gin_convs[i](edge_index, feature_list[i])
+            h = self.norms[i](h)
+            h = self.relu(h)
+            feature_list.append(h)
+
+        output = 0
+        for i, h in enumerate(feature_list):
+            h = F.dropout(h, p=self.dropout_prob, training=self.training)
+            output += self.linears[i](h)
+
+        return output
\ No newline at end of file
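Note on base_conv.py and gin.py: the base_conv.py hunk only renames the
keyword passed to paddle.geometric.send_u_recv from pool_type to reduce_op,
matching the argument name in current Paddle releases. GINConv computes the
GIN update MLP((1 + eps) * h_v + sum of neighbor features), with eps
learnable when train_eps=True; this variant targets full-graph node
classification (pool_type is accepted but no graph-level pooling happens),
and the logits are a sum of per-layer linear heads over [input, h_1, ...,
h_L]. A usage sketch under the same invented toy-graph assumptions as above:

    import paddle
    from models import GIN

    edges = paddle.to_tensor([[0, 1], [1, 2], [2, 0]], dtype="int64")
    feat = paddle.randn([3, 8])  # 3 nodes, 8 input features
    model = GIN(input_size=8, num_class=2, num_layers=2, hidden_size=16)
    model.eval()
    logits = model(edges, feat)  # GIN's forward takes (edges, feature), no num_nodes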
diff --git a/examples/gnn_depoly/models/sgc.py b/examples/gnn_depoly/models/sgc.py
new file mode 100644
index 00000000..558acb0c
--- /dev/null
+++ b/examples/gnn_depoly/models/sgc.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .base_conv import BaseConv
+
+
+def gcn_norm(edge_index, num_nodes):
+    _, col = edge_index[:, 0], edge_index[:, 1]
+    degree = paddle.zeros(shape=[num_nodes], dtype="int64")
+    degree = paddle.scatter(
+        x=degree,
+        index=col,
+        updates=paddle.ones_like(
+            col, dtype="int64"),
+        overwrite=False)
+    norm = paddle.cast(degree, dtype=paddle.get_default_dtype())
+    norm = paddle.clip(norm, min=1.0)
+    norm = paddle.pow(norm, -0.5)
+    norm = paddle.reshape(norm, [-1, 1])
+    return norm
+
+
+class SGCConv(BaseConv):
+    def __init__(self, input_size, output_size, k_hop=2, cached=True, activation=None, bias=False):
+        super(SGCConv, self).__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.k_hop = k_hop
+        self.linear = nn.Linear(input_size, output_size, bias_attr=False)
+        if bias:
+            self.bias = self.create_parameter(shape=[output_size], is_bias=True)
+
+        self.cached = cached
+        self.cached_output = None
+        if isinstance(activation, str):
+            activation = getattr(F, activation)
+        self.activation = activation
+
+    def forward(self, edge_index, num_nodes, feature):
+        if self.cached:
+            if self.cached_output is None:
+                norm = gcn_norm(edge_index, num_nodes)
+                for _ in range(self.k_hop):
+                    feature = feature * norm
+                    feature = self.send_recv(edge_index, feature, "sum")
+                    feature = feature * norm
+                self.cached_output = feature
+            else:
+                feature = self.cached_output
+        else:
+            norm = gcn_norm(edge_index, num_nodes)
+            for _ in range(self.k_hop):
+                feature = feature * norm
+                feature = self.send_recv(edge_index, feature, "sum")
+                feature = feature * norm
+
+        output = self.linear(feature)
+        if hasattr(self, "bias"):
+            output = output + self.bias
+
+        if self.activation is not None:
+            output = self.activation(output)
+        return output
+
+
+class SGC(nn.Layer):
+    def __init__(self, input_size, num_class, num_layers=1, **kwargs):
+        super(SGC, self).__init__()
+        self.num_class = num_class
+        self.num_layers = num_layers
+        self.sgc_layer = SGCConv(
+            input_size=input_size, output_size=num_class, k_hop=num_layers)
+
+    def forward(self, edge_index, num_nodes, feature):
+        feature = self.sgc_layer(edge_index, num_nodes, feature)
+        return feature
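Note on sgc.py: SGC collapses a k-layer GCN into k_hop rounds of fixed,
parameter-free propagation followed by a single linear layer. With
cached=True (the default) the propagated features are computed on the first
forward pass and reused on every later call, which is only valid while the
graph and input features stay fixed, as in this full-graph transductive
setting; pass cached=False otherwise. Sketch, same invented toy graph as
above:

    import paddle
    from models import SGC

    edges = paddle.to_tensor([[0, 1], [1, 0]], dtype="int64")
    feat = paddle.randn([2, 16])
    model = SGC(input_size=16, num_class=3, num_layers=2)  # k_hop = num_layers
    logits = model(edges, 2, feat)  # first call propagates and caches
    logits = model(edges, 2, feat)  # later calls reuse the cached features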
diff --git a/examples/gnn_depoly/node_classification_with_full_graph/export_model.py b/examples/gnn_depoly/node_classification_with_full_graph/export_model.py
index ff5a81e9..82320854 100644
--- a/examples/gnn_depoly/node_classification_with_full_graph/export_model.py
+++ b/examples/gnn_depoly/node_classification_with_full_graph/export_model.py
@@ -21,7 +21,7 @@
 import pgl
 
 sys.path.insert(0, os.path.abspath(".."))
-from models import GCN, GAT, GraphSage
+from models import GCN, GAT, GraphSage, GIN, SGC, APPNP
 
 
 def save_static_model(args):
@@ -35,6 +35,21 @@ def save_static_model(args):
                     attn_drop=0.6,
                     num_heads=8,
                     hidden_size=8)
+    elif args.model == "APPNP":
+        model = APPNP(input_size=dataset.graph.node_feat["words"].shape[1],
+                      num_class=dataset.num_classes,
+                      num_layers=2,
+                      alpha=0.1,
+                      k_hop=10)
+    elif args.model == "SGC":
+        model = SGC(input_size=dataset.graph.node_feat["words"].shape[1],
+                    num_class=dataset.num_classes,
+                    num_layers=2)
+    elif args.model == "GIN":
+        model = GIN(input_size=dataset.graph.node_feat["words"].shape[1],
+                    num_class=dataset.num_classes,
+                    num_layers=2,
+                    hidden_size=16)
     elif args.model == "GCN":
         model = GCN(input_size=dataset.graph.node_feat["words"].shape[1],
                     num_class=dataset.num_classes,
@@ -54,7 +69,7 @@ def save_static_model(args):
     model.eval()
 
     # Convert to static graph with specific input description
-    if args.model == "GraphSage":
+    if args.model == "GraphSage" or args.model == "GIN":
         model = paddle.jit.to_static(
             model,
             input_spec=[
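Note on export_model.py: GIN joins the GraphSage branch because both forwards
take (edges, feature), while GCN/GAT/SGC/APPNP take (edges, num_nodes,
feature), so the two groups need different paddle.jit.to_static input specs.
A sketch of the two-argument spec the GIN path would go through (the real
spec list sits in the unchanged branch body above; shapes here are
placeholders):

    model = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(shape=[None, 2], dtype="int64"),  # edges
            paddle.static.InputSpec(shape=[None, None], dtype="float32"),  # node features
        ])
    paddle.jit.save(model, "./export/model")  # output path is illustrative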
diff --git a/examples/gnn_depoly/node_classification_with_full_graph/train.py b/examples/gnn_depoly/node_classification_with_full_graph/train.py
index 139be643..b3c7719f 100644
--- a/examples/gnn_depoly/node_classification_with_full_graph/train.py
+++ b/examples/gnn_depoly/node_classification_with_full_graph/train.py
@@ -24,7 +24,7 @@
 from paddle.optimizer import Adam
 
 sys.path.insert(0, os.path.abspath(".."))
-from models import GCN, GAT, GraphSage
+from models import GCN, GAT, GraphSage, GIN, SGC, APPNP
 
 
 def normalize(feat):
@@ -59,7 +59,7 @@ def load():
 
 def train(node_index, node_label, gnn_model, graph, criterion, optim, args):
     gnn_model.train()
-    if args.model == "GraphSage":
+    if args.model == "GraphSage" or args.model == "GIN":
         pred = gnn_model(graph.edges, graph.node_feat["words"])
     else:
         pred = gnn_model(graph.edges, graph.num_nodes,
@@ -76,7 +76,7 @@ def train(node_index, node_label, gnn_model, graph, criterion, optim, args):
 @paddle.no_grad()
 def eval(node_index, node_label, gnn_model, graph, criterion, args):
     gnn_model.eval()
-    if args.model == "GraphSage":
+    if args.model == "GraphSage" or args.model == "GIN":
         pred = gnn_model(graph.edges, graph.node_feat["words"])
     else:
         pred = gnn_model(graph.edges, graph.num_nodes,
@@ -108,6 +108,21 @@ def main(args):
                         attn_drop=0.6,
                         num_heads=8,
                         hidden_size=8)
+    elif args.model == "SGC":
+        gnn_model = SGC(input_size=graph.node_feat["words"].shape[1],
+                        num_class=dataset.num_classes,
+                        num_layers=2)
+    elif args.model == "APPNP":
+        gnn_model = APPNP(input_size=graph.node_feat["words"].shape[1],
+                          num_class=dataset.num_classes,
+                          num_layers=2,
+                          alpha=0.1,
+                          k_hop=10)
+    elif args.model == "GIN":
+        gnn_model = GIN(input_size=graph.node_feat["words"].shape[1],
+                        num_class=dataset.num_classes,
+                        num_layers=2,
+                        hidden_size=16)
     elif args.model == "GCN":
         gnn_model = GCN(input_size=graph.node_feat["words"].shape[1],
                         num_class=dataset.num_classes,
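Note on train.py: train() and eval() apply the same two-signature dispatch as
the export script, and main() grows matching SGC/APPNP/GIN branches. Assuming
the scripts' existing --model flag, the new models would be exercised with
something like:

    python train.py --model SGC
    python train.py --model APPNP
    python train.py --model GIN
    python export_model.py --model GIN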