Skip to content

Commit

Permalink
Convert large function to SkeletonDecoder class
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Sep 12, 2024
1 parent 34a8cfa commit 8b2005a
Showing 1 changed file with 68 additions and 85 deletions.
153 changes: 68 additions & 85 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def matches(self, other: "Node") -> bool:
return other.name == self.name and other.weight == self.weight


def replace_jsonpickle_decode(json_str: str) -> Any:
"""Replace jsonpickle.decode with own decoder.
class SkeletonDecoder:
"""Replace jsonpickle.decode with our own decoder.
This function will decode the following from their encoded format:
This function will decode the following from jsonpickle's encoded format:
`Node` objects from
{
Expand Down Expand Up @@ -119,9 +119,10 @@ def replace_jsonpickle_decode(json_str: str) -> Any:
to the object with the same reconstruction id (from top to bottom).
"""

def decode_id(
id: int, objects: List[Union[Node, EdgeType]]
) -> Union[Node, EdgeType]:
def __init__(self):
self.decoded_objects: List[Union[Node, EdgeType]] = []

def _decode_id(self, id: int) -> Union[Node, EdgeType]:
"""Decode the object with the given `py/id` value of `id`.
Args:
Expand All @@ -131,9 +132,10 @@ def decode_id(
Returns:
The object with the given `py/id` value.
"""
return objects[id - 1]
return self.decoded_objects[id - 1]

def decode_state(state: dict) -> Node:
@staticmethod
def _decode_state(state: dict) -> Node:
"""Reconstruct the `Node` object from 'py/state' key in the serialized nx_graph.
We support states in either dictionary or tuple format:
Expand All @@ -157,7 +159,8 @@ def decode_state(state: dict) -> Node:

return Node(**state)

def decode_object_dict(object_dict) -> Node:
@staticmethod
def _decode_object_dict(object_dict) -> Node:
"""Decode dict containing `py/object` key in the serialized nx_graph.
Args:
Expand All @@ -175,36 +178,33 @@ def decode_object_dict(object_dict) -> Node:
if object_dict["py/object"] != "sleap.skeleton.Node":
raise ValueError("Only 'sleap.skeleton.Node' objects are supported.")

node: Node = decode_state(object_dict["py/state"])
node: Node = SkeletonDecoder._decode_state(state=object_dict["py/state"])
return node

def decode_node(
encoded_node: dict, decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[Node, List[Union[Node, EdgeType]]]:
def _decode_node(self, encoded_node: dict) -> Node:
"""Decode an item believed to be an encoded `Node` object.
Also updates the list of decoded objects.
Args:
encoded_node: The encoded node to decode.
decoded_objects: The list of decoded objects so far.
Returns:
The decoded node and the updated list of decoded objects.
"""

if isinstance(encoded_node, int):
# Using index mapping to replace the object (load from Labels)
return encoded_node, decoded_objects
return encoded_node
elif "py/object" in encoded_node:
decoded_node: Node = decode_object_dict(encoded_node)
decoded_objects.append(decoded_node)
decoded_node: Node = SkeletonDecoder._decode_object_dict(encoded_node)
self.decoded_objects.append(decoded_node)
elif "py/id" in encoded_node:
decoded_node: Node = decode_id(encoded_node["py/id"], decoded_objects)
decoded_node: Node = self._decode_id(encoded_node["py/id"])

return decoded_node, decoded_objects
return decoded_node

def decode_nodes(
encoded_nodes: List[dict], decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[List[Dict[str, Node]], List[Union[Node, EdgeType]]]:
def _decode_nodes(self, encoded_nodes: List[dict]) -> List[Dict[str, Node]]:
"""Decode the 'nodes' key in the serialized nx_graph.
The encoded_nodes is a list of dictionary of two types:
Expand All @@ -213,23 +213,20 @@ def decode_nodes(
Args:
encoded_nodes: The list of encoded nodes to decode.
decoded_objects: The list of decoded objects so far.
Returns:
The decoded nodes and the updated list of decoded objects.
The decoded nodes.
"""

decoded_nodes: List[Dict[str, Node]] = []
for e_node_dict in encoded_nodes:
e_node = e_node_dict["id"]
d_node, decoded_objects = decode_node(e_node, decoded_objects)
d_node = self._decode_node(e_node)
decoded_nodes.append({"id": d_node})

return decoded_nodes, decoded_objects
return decoded_nodes

def decode_reduce_dict(
reduce_dict: Dict[str, List[dict]], decoded_objects: List[Union[Node, EdgeType]]
) -> EdgeType:
def _decode_reduce_dict(self, reduce_dict: Dict[str, List[dict]]) -> EdgeType:
"""Decode the 'reduce' key in the serialized nx_graph.
The reduce_dict is a dictionary in the following format:
Expand All @@ -242,7 +239,6 @@ def decode_reduce_dict(
Args:
reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...}
decoded_objects: The list of decoded objects so far.
Returns:
The decoded `EdgeType` object.
Expand Down Expand Up @@ -272,38 +268,30 @@ def decode_reduce_dict(
)

edge = EdgeType(edge_type)
decoded_objects.append(edge)
self.decoded_objects.append(edge)

return edge, decoded_objects
return edge

def decode_edge_type(
encoded_edge_type: dict, decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[EdgeType, List[Union[Node, EdgeType]]]:
def _decode_edge_type(self, encoded_edge_type: dict) -> EdgeType:
"""Decode the 'type' key in the serialized nx_graph.
Args:
encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key.
decoded_objects: The list of decoded objects so far.
Returns:
A tuple including the decoded `EdgeType` object and the updated list of
decoded objects.
The decoded `EdgeType` object.
"""

if "py/reduce" in encoded_edge_type:
edge_type, decoded_objects = decode_reduce_dict(
encoded_edge_type, decoded_objects=decoded_objects
)
edge_type = self._decode_reduce_dict(encoded_edge_type)
else:
# Expect a "py/id" instead of "py/reduce"
edge_type = decode_id(encoded_edge_type["py/id"], decoded_objects)
return edge_type, decoded_objects

def decode_links(
links: List[dict], decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[
List[Dict[str, Union[int, Node, EdgeType]]], List[Union[Node, EdgeType]]
]:
edge_type = self._decode_id(encoded_edge_type["py/id"])
return edge_type

def _decode_links(
self, links: List[dict]
) -> List[Dict[str, Union[int, Node, EdgeType]]]:
"""Decode the 'links' key in the serialized nx_graph.
The links are the edges in the graph and will have the following keys:
Expand All @@ -314,51 +302,51 @@ def decode_links(
Args:
encoded_links: The list of encoded links to decode.
decoded_objects: The list of decoded objects so far.
"""

for link in links:
for key, value in link.items():
if key == "source":
link[key], decoded_objects = decode_node(value, decoded_objects)
link[key] = self._decode_node(value)
elif key == "target":
link[key], decoded_objects = decode_node(value, decoded_objects)
link[key] = self._decode_node(value)
elif key == "type":
link[key], decoded_objects = decode_edge_type(
value, decoded_objects
)
link[key] = self._decode_edge_type(value)

return links, decoded_objects
return links

dicts = json.loads(json_str)
def _decode(self, json_str: str):
dicts = json.loads(json_str)

# Enforce same format across template and non-template skeletons
if "nx_graph" not in dicts:
# Non-template skeletons use the dicts as the "nx_graph"
dicts = {"nx_graph": dicts}
# Enforce same format across template and non-template skeletons
if "nx_graph" not in dicts:
# Non-template skeletons use the dicts as the "nx_graph"
dicts = {"nx_graph": dicts}

# Decode the graph
nx_graph = dicts["nx_graph"]
# Decode the graph
nx_graph = dicts["nx_graph"]

decoded_objects = []
for key, value in nx_graph.items():
if key == "nodes":
nx_graph[key], decoded_objects = decode_nodes(
value, decoded_objects=decoded_objects
)
elif key == "links":
nx_graph[key], decoded_objects = decode_links(
value, decoded_objects=decoded_objects
self.decoded_objects = [] # Reset the decoded objects incase reusing decoder
for key, value in nx_graph.items():
if key == "nodes":
nx_graph[key] = self._decode_nodes(value)
elif key == "links":
nx_graph[key] = self._decode_links(value)

# Decode the preview image (if it exists)
preview_image = dicts.get("preview_image", None)
if preview_image is not None:
dicts["preview_image"] = decode_preview_image(
preview_image["py/b64"], return_bytes=True
)

# Decode the preview image (if it exists)
preview_image = dicts.get("preview_image", None)
if preview_image is not None:
dicts["preview_image"] = decode_preview_image(
preview_image["py/b64"], return_bytes=True
)
return dicts

return dicts
@classmethod
def decode(cls, json_str: str) -> Dict:
"""Decode the given json string into a dictionary."""
decoder = cls()
return decoder._decode(json_str)


class Skeleton:
Expand Down Expand Up @@ -1347,7 +1335,7 @@ def from_json(
Returns:
An instance of the `Skeleton` object decoded from the JSON.
"""
dicts: dict = replace_jsonpickle_decode(json_str)
dicts: dict = SkeletonDecoder.decode(json_str)
nx_graph = dicts.get("nx_graph", dicts)
graph = json_graph.node_link_graph(nx_graph)

Expand Down Expand Up @@ -1557,8 +1545,3 @@ def __hash__(self):

cattr.register_unstructure_hook(Skeleton, lambda skeleton: Skeleton.to_dict(skeleton))
cattr.register_structure_hook(Skeleton, lambda dicts, cls: Skeleton.from_dict(dicts))

if __name__ == "__main__":
ds = "sleap/skeletons/bees.json"
sk = Skeleton.load_json(ds)
print("debug hook")

0 comments on commit 8b2005a

Please sign in to comment.