Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unbatch & test #317

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 141 additions & 1 deletion pgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ def _process_graph_info(self, **kwargs):

@classmethod
def disjoint(cls, graph_list, merged_graph_index=False):
"""This method disjoint list of graph into a big graph.
"""This method disjoint list of graphs into a big graph.

Args:

Expand Down Expand Up @@ -1150,11 +1150,68 @@ def disjoint(cls, graph_list, merged_graph_index=False):
_graph_edge_index=graph_edge_index)
return graph

@classmethod
def split(cls, graph):
"""This method split graph according to graph_node_id into list of graphs.

Args:

graph (Graph): A Graph.

.. code-block:: python

import numpy as np
import pgl

num_nodes = 5
edges = [ (0, 1), (1, 2), (3, 4)]
graph_1 = pgl.Graph(num_nodes=num_nodes,
edges=edges)

num_nodes = 3
edges = [ (0, 1), (1, 2)]
graph_2 = pgl.Graph(num_nodes=num_nodes,
edges=edges)

joint_graph = pgl.Graph.disjoint([graph_1, graph_2])
print(joint_graph.graph_node_id)
>>> [0, 0, 0, 0, 0, 1, 1, 1]
print(joint_graph.num_graph)
>>> 2

graph_list = pgl.Graph.split(joint_graph)
print(graph_list[0].graph_node_id)
>>> [0, 0, 0, 0, 0]
print(graph_list[1].graph_node_id)
>>> [0, 0, 0]
"""

edges_list = cls._split_edges(graph)
num_nodes_list = cls._split_nodes(graph)
node_feat_list = cls._split_feature(graph, mode="node")
edge_feat_list = cls._split_feature(graph, mode="edge")

graph_list = []
for edges, num_nodes, node_feat, edge_feat in zip(
edges_list, num_nodes_list, node_feat_list, edge_feat_list):
graph = cls(num_nodes=num_nodes,
edges=edges,
node_feat=node_feat,
edge_feat=edge_feat)
graph_list.append(graph)

return graph_list

@staticmethod
def batch(graph_list):
"""This is alias on `pgl.Graph.disjoint` with merged_graph_index=False"""
return Graph.disjoint(graph_list, merged_graph_index=False)

@staticmethod
def unbatch(graph):
"""This is alias on `pgl.Graph.split`"""
return Graph.split(graph)

@staticmethod
def _join_graph_index(graph_list, mode="node"):
is_tensor = graph_list[0].is_tensor()
Expand All @@ -1178,6 +1235,20 @@ def _join_nodes(graph_list):
num_nodes = g.num_nodes + num_nodes
return num_nodes

@staticmethod
def _split_nodes(graph):
is_tensor = graph.is_tensor()
num_nodes_list = []
start_list, end_list = graph._graph_node_index[:
-1], graph._graph_node_index[
1:]
if is_tensor:
start_list = start_list.numpy().tolist()
end_list = end_list.numpy().tolist()
for start, end in zip(start_list, end_list):
num_nodes_list.append(end - start)
return num_nodes_list

@staticmethod
def _join_feature(graph_list, mode="node"):
"""join node features for multiple graph"""
Expand Down Expand Up @@ -1207,6 +1278,54 @@ def _join_feature(graph_list, mode="node"):
ret_feat[key] = np.concatenate(feat[key], axis=0)
return ret_feat

@staticmethod
def _split_feature(graph, mode="node"):
"""split node features of graph according to graph_node_id"""
is_tensor = graph.is_tensor()
feat_list = []
if mode == "node":
start_list, end_list = graph._graph_node_index[:
-1], graph._graph_node_index[
1:]
if is_tensor:
start_list = start_list.numpy().tolist()
end_list = end_list.numpy().tolist()
for start, end in zip(start_list, end_list):
feat = defaultdict(lambda: [])
for key in graph.node_feat:
feat[key].append(graph.node_feat[key][start:end])
feat_list.append(feat)
elif mode == "edge":
start_list, end_list = graph._graph_edge_index[:
-1], graph._graph_edge_index[
1:]
if is_tensor:
start_list = start_list.numpy().tolist()
end_list = end_list.numpy().tolist()
for start, end in zip(start_list, end_list):
feat = defaultdict(lambda: [])
for key in graph.edge_feat:
feat[key].append(graph.edge_feat[key][start:end])
feat_list.append(feat)
else:
raise ValueError(
"mode must be in ['node', 'edge']. But received model=%s" %
mode)

feat_list_temp = []
for feat in feat_list:
ret_feat = {}
for key in feat:
if len(feat[key]) == 1:
ret_feat[key] = feat[key][0]
else:
if is_tensor:
ret_feat[key] = paddle.concat(feat[key], 0)
else:
ret_feat[key] = np.concatenate(feat[key], axis=0)
feat_list_temp.append(ret_feat)
return feat_list_temp

@staticmethod
def _join_edges(graph_list):
"""join edges for multiple graph"""
Expand All @@ -1228,6 +1347,27 @@ def _join_edges(graph_list):
edges = np.concatenate(list_edges, axis=0)
return edges

@staticmethod
def _split_edges(graph):
"""split edges according to graph_edge_id"""
is_tensor = graph.is_tensor()
start_offset_list = graph._graph_node_index[:-1]
start_list, end_list = graph._graph_edge_index[:
-1], graph._graph_edge_index[
1:]
if is_tensor:
start_list = start_list.numpy().tolist()
end_list = end_list.numpy().tolist()
start_offset_list = start_offset_list.numpy().tolist()

edges_list = []
for start, end, start_offset in zip(start_list, end_list,
start_offset_list):
edges = graph.edges[start:end]
edges -= start_offset
edges_list.append(edges)
return edges_list

def node_batch_iter(self, batch_size, shuffle=True):
"""Node batch iterator

Expand Down
73 changes: 73 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,79 @@ def test_node_iter(self):
batch_size=3, shuffle=shuffle):
break

def test_batch_and_unbatch(self):
graph_1 = create_random_graph().numpy()
graph_2 = create_random_graph().numpy()
joint_graph = pgl.Graph.batch([graph_1, graph_2])
graph_list = pgl.Graph.unbatch(joint_graph)

self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
).all())
self.assertTrue((graph_list[1].graph_node_id == graph_2.graph_node_id
).all())
self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
).all())
self.assertTrue((graph_list[1].graph_edge_id == graph_2.graph_edge_id
).all())

self.assertTrue((graph_list[0].node_feat["nfeat"] == graph_1.node_feat[
"nfeat"]).all())
self.assertTrue((graph_list[1].node_feat["nfeat"] == graph_2.node_feat[
"nfeat"]).all())
self.assertTrue((graph_list[0].edge_feat["efeat"] == graph_1.edge_feat[
"efeat"]).all())
self.assertTrue((graph_list[1].edge_feat["efeat"] == graph_2.edge_feat[
"efeat"]).all())

graph_1 = create_random_graph().tensor()
graph_2 = create_random_graph().tensor()
joint_graph = pgl.Graph.batch([graph_1, graph_2])
graph_list = pgl.Graph.unbatch(joint_graph)

self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
).all())
self.assertTrue((graph_list[1].graph_node_id == graph_2.graph_node_id
).all())
self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
).all())
self.assertTrue((graph_list[1].graph_edge_id == graph_2.graph_edge_id
).all())

self.assertTrue((graph_list[0].node_feat["nfeat"].numpy() ==
graph_1.node_feat["nfeat"].numpy()).all())
self.assertTrue((graph_list[1].node_feat["nfeat"].numpy() ==
graph_2.node_feat["nfeat"].numpy()).all())
self.assertTrue((graph_list[0].edge_feat["efeat"].numpy() ==
graph_1.edge_feat["efeat"].numpy()).all())
self.assertTrue((graph_list[1].edge_feat["efeat"].numpy() ==
graph_2.edge_feat["efeat"].numpy()).all())

graph_1 = create_random_graph()
joint_graph = pgl.Graph.batch([graph_1])
graph_list = pgl.Graph.unbatch(joint_graph)

self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
).all())
self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
).all())
self.assertTrue((graph_list[0].node_feat["nfeat"] == graph_1.node_feat[
"nfeat"]).all())
self.assertTrue((graph_list[0].edge_feat["efeat"] == graph_1.edge_feat[
"efeat"]).all())

graph_1 = create_random_graph().tensor()
joint_graph = pgl.Graph.batch([graph_1])
graph_list = pgl.Graph.unbatch(joint_graph)

self.assertTrue((graph_list[0].graph_node_id == graph_1.graph_node_id
).all())
self.assertTrue((graph_list[0].graph_edge_id == graph_1.graph_edge_id
).all())
self.assertTrue((graph_list[0].node_feat["nfeat"].numpy() ==
graph_1.node_feat["nfeat"].numpy()).all())
self.assertTrue((graph_list[0].edge_feat["efeat"].numpy() ==
graph_1.edge_feat["efeat"].numpy()).all())


if __name__ == "__main__":
unittest.main()