From f804346bb8f345f9a91c9a9f7b9e492b32e8f5b1 Mon Sep 17 00:00:00 2001
From: shiyunsheng01
Date: Wed, 25 Aug 2021 13:50:26 +0800
Subject: [PATCH 1/2] add unbatch & test

---
 pgl/graph.py        | 140 ++++++++++++++++++++++++++++++++++++++++++++
 tests/test_graph.py |  77 ++++++++++++++++++++++++
 2 files changed, 217 insertions(+)

diff --git a/pgl/graph.py b/pgl/graph.py
index 50aaccc4..849adbeb 100644
--- a/pgl/graph.py
+++ b/pgl/graph.py
@@ -1150,11 +1150,68 @@ def disjoint(cls, graph_list, merged_graph_index=False):
                     _graph_edge_index=graph_edge_index)
         return graph
 
+    @classmethod
+    def split(cls, graph):
+        """This method split graph according to graph_node_id into list of graph.
+
+        Args:
+
+            graph (Graph): A Graph.
+
+        Returns:
+
+            list: one Graph per graph id, with node ids renumbered from 0.
+
+        .. code-block:: python
+
+            import pgl
+
+            num_nodes = 5
+            edges = [ (0, 1), (1, 2), (3, 4)]
+            graph_1 = pgl.Graph(num_nodes=num_nodes, edges=edges)
+
+            num_nodes = 3
+            edges = [ (0, 1), (1, 2)]
+            graph_2 = pgl.Graph(num_nodes=num_nodes, edges=edges)
+
+            joint_graph = pgl.Graph.disjoint([graph_1, graph_2])
+            print(joint_graph.graph_node_id)
+            >>> [0, 0, 0, 0, 0, 1, 1, 1]
+            print(joint_graph.num_graph)
+            >>> 2
+
+            graph_list = pgl.Graph.split(joint_graph)
+            print(graph_list[0].graph_node_id)
+            >>> [0, 0, 0, 0, 0]
+            print(graph_list[1].graph_node_id)
+            >>> [0, 0, 0]
+        """
+        edges_list = cls._split_edges(graph)
+        num_nodes_list = cls._split_nodes(graph)
+        node_feat_list = cls._split_feature(graph, mode="node")
+        edge_feat_list = cls._split_feature(graph, mode="edge")
+
+        graph_list = []
+        for edges, num_nodes, node_feat, edge_feat in zip(
+                edges_list, num_nodes_list, node_feat_list, edge_feat_list):
+            sub_graph = cls(num_nodes=num_nodes,
+                            edges=edges,
+                            node_feat=node_feat,
+                            edge_feat=edge_feat)
+            graph_list.append(sub_graph)
+
+        return graph_list
+
     @staticmethod
     def batch(graph_list):
         """This is alias on `pgl.Graph.disjoint` with merged_graph_index=False"""
         return Graph.disjoint(graph_list, merged_graph_index=False)
 
+    @staticmethod
+    def unbatch(graph):
+        """This is alias on `pgl.Graph.split`"""
+        return Graph.split(graph)
+
     @staticmethod
     def _join_graph_index(graph_list, mode="node"):
         is_tensor = graph_list[0].is_tensor()
@@ -1178,6 +1235,20 @@ def _join_nodes(graph_list):
             num_nodes = g.num_nodes + num_nodes
         return num_nodes
 
+    @staticmethod
+    def _split_nodes(graph):
+        """split the node count of a batched graph into per-graph counts
+
+        Consecutive entries of `graph._graph_node_index` delimit the node
+        range of each sub-graph, so every count is `end - start`.
+        """
+        start_list = graph._graph_node_index[:-1]
+        end_list = graph._graph_node_index[1:]
+        if graph.is_tensor():
+            start_list = start_list.numpy().tolist()
+            end_list = end_list.numpy().tolist()
+        return [end - start for start, end in zip(start_list, end_list)]
+
     @staticmethod
     def _join_feature(graph_list, mode="node"):
         """join node features for multiple graph"""
@@ -1207,6 +1278,54 @@ def _join_feature(graph_list, mode="node"):
                     ret_feat[key] = np.concatenate(feat[key], axis=0)
         return ret_feat
 
+    @staticmethod
+    def _split_feature(graph, mode="node"):
+        """split node or edge features of a batched graph
+
+        This is the inverse of `_join_feature`: every feature is cut back
+        into the contiguous slice that belongs to each sub-graph.
+
+        Args:
+
+            graph (Graph): A batched Graph.
+
+            mode (str): "node" splits `graph.node_feat` along
+                `graph._graph_node_index`; "edge" splits
+                `graph.edge_feat` along `graph._graph_edge_index`.
+
+        Returns:
+
+            list of dict: one `{name: feature}` dict per sub-graph.
+
+        Raises:
+
+            ValueError: if `mode` is neither "node" nor "edge".
+        """
+        if mode == "node":
+            feat_dict = graph.node_feat
+            graph_index = graph._graph_node_index
+        elif mode == "edge":
+            feat_dict = graph.edge_feat
+            graph_index = graph._graph_edge_index
+        else:
+            raise ValueError(
+                "mode must be in ['node', 'edge']. But received mode=%s" %
+                mode)
+
+        start_list, end_list = graph_index[:-1], graph_index[1:]
+        if graph.is_tensor():
+            start_list = start_list.numpy().tolist()
+            end_list = end_list.numpy().tolist()
+
+        # Each sub-graph owns one contiguous [start, end) slice per feature.
+        feat_list = []
+        for start, end in zip(start_list, end_list):
+            feat_list.append({
+                key: feat_dict[key][start:end]
+                for key in feat_dict
+            })
+        return feat_list
+
     @staticmethod
     def _join_edges(graph_list):
         """join edges for multiple graph"""
@@ -1228,6 +1347,27 @@ def _join_edges(graph_list):
             edges = np.concatenate(list_edges, axis=0)
         return edges
 
+    @staticmethod
+    def _split_edges(graph):
+        """split edges according to graph_edge_id
+
+        Node ids are shifted back so every sub-graph starts from node 0.
+        """
+        offset_list = graph._graph_node_index[:-1]
+        start_list = graph._graph_edge_index[:-1]
+        end_list = graph._graph_edge_index[1:]
+        if graph.is_tensor():
+            offset_list = offset_list.numpy().tolist()
+            start_list = start_list.numpy().tolist()
+            end_list = end_list.numpy().tolist()
+
+        edges_list = []
+        for start, end, offset in zip(start_list, end_list, offset_list):
+            # Out-of-place subtraction: a numpy slice is a view, so `-=`
+            # would silently shift the edges of the batched `graph` too.
+            edges_list.append(graph.edges[start:end] - offset)
+        return edges_list
+
     def node_batch_iter(self, batch_size, shuffle=True):
         """Node batch iterator
 
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 10eefaf2..dfa1a991 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -435,6 +435,83 @@ def test_node_iter(self):
                     batch_size=3, shuffle=shuffle):
                 break
 
+    def test_batch_and_unbatch(self):
+        graph_1 = create_random_graph().numpy()
+        graph_2 = create_random_graph().numpy()
+        joint_graph = pgl.Graph.batch([graph_1, graph_2])
+        joint_edges = joint_graph.edges.copy()
+        graph_list = pgl.Graph.unbatch(joint_graph)
+
+        # unbatch must renumber the edges without touching the batched graph.
+        self.assertTrue((joint_graph.edges == joint_edges).all())
+        self.assertTrue((graph_list[1].edges == graph_2.edges).all())
+        self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
+                         ).all())
+        self.assertTrue((graph_list[1].graph_node_id == graph_2.graph_node_id
+                         ).all())
+        self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
+                         ).all())
+        self.assertTrue((graph_list[1].graph_edge_id == graph_2.graph_edge_id
+                         ).all())
+
+        self.assertTrue((graph_list[0].node_feat["nfeat"] == graph_1.node_feat[
+            "nfeat"]).all())
+        self.assertTrue((graph_list[1].node_feat["nfeat"] == graph_2.node_feat[
+            "nfeat"]).all())
+        self.assertTrue((graph_list[0].edge_feat["efeat"] == graph_1.edge_feat[
+            "efeat"]).all())
+        self.assertTrue((graph_list[1].edge_feat["efeat"] == graph_2.edge_feat[
+            "efeat"]).all())
+
+        graph_1 = create_random_graph().tensor()
+        graph_2 = create_random_graph().tensor()
+        joint_graph = pgl.Graph.batch([graph_1, graph_2])
+        graph_list = pgl.Graph.unbatch(joint_graph)
+
+        self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
+                         ).all())
+        self.assertTrue((graph_list[1].graph_node_id == graph_2.graph_node_id
+                         ).all())
+        self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
+                         ).all())
+        self.assertTrue((graph_list[1].graph_edge_id == graph_2.graph_edge_id
+                         ).all())
+
+        self.assertTrue((graph_list[0].node_feat["nfeat"].numpy() ==
+                         graph_1.node_feat["nfeat"].numpy()).all())
+        self.assertTrue((graph_list[1].node_feat["nfeat"].numpy() ==
+                         graph_2.node_feat["nfeat"].numpy()).all())
+        self.assertTrue((graph_list[0].edge_feat["efeat"].numpy() ==
+                         graph_1.edge_feat["efeat"].numpy()).all())
+        self.assertTrue((graph_list[1].edge_feat["efeat"].numpy() ==
+                         graph_2.edge_feat["efeat"].numpy()).all())
+
+        graph_1 = create_random_graph()
+        joint_graph = pgl.Graph.batch([graph_1])
+        graph_list = pgl.Graph.unbatch(joint_graph)
+
+        self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
+                         ).all())
+        self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
+                         ).all())
+        self.assertTrue((graph_list[0].node_feat["nfeat"] == graph_1.node_feat[
+            "nfeat"]).all())
+        self.assertTrue((graph_list[0].edge_feat["efeat"] == graph_1.edge_feat[
+            "efeat"]).all())
+
+        graph_1 = create_random_graph().tensor()
+        joint_graph = pgl.Graph.batch([graph_1])
+        graph_list = pgl.Graph.unbatch(joint_graph)
+
+        self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
+                         ).all())
+        self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
+                         ).all())
+        self.assertTrue((graph_list[0].node_feat["nfeat"].numpy() ==
+                         graph_1.node_feat["nfeat"].numpy()).all())
+        self.assertTrue((graph_list[0].edge_feat["efeat"].numpy() ==
+                         graph_1.edge_feat["efeat"].numpy()).all())
+
 
 if __name__ == "__main__":
     unittest.main()

From 50a98958a93844982cf6cac7812894741b2f65cd Mon Sep 17 00:00:00 2001
From: shiyunsheng01
Date: Wed, 25 Aug 2021 14:56:30 +0800
Subject: [PATCH 2/2] add unbatch & test

---
 pgl/graph.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pgl/graph.py b/pgl/graph.py
index 849adbeb..1ae47970 100644
--- a/pgl/graph.py
+++ b/pgl/graph.py
@@ -1088,7 +1088,7 @@ def _process_graph_info(self, **kwargs):
 
     @classmethod
     def disjoint(cls, graph_list, merged_graph_index=False):
-        """This method disjoint list of graph into a big graph.
+        """This method disjoint list of graphs into a big graph.
 
         Args:
 
@@ -1152,7 +1152,7 @@ def disjoint(cls, graph_list, merged_graph_index=False):
 
     @classmethod
     def split(cls, graph):
-        """This method split graph according to graph_node_id into list of graph.
+        """This method split graph according to graph_node_id into list of graphs.
 
         Args:
 