diff --git a/sleap/skeleton.py b/sleap/skeleton.py index e9801db99..febf05851 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -85,10 +85,10 @@ def matches(self, other: "Node") -> bool: return other.name == self.name and other.weight == self.weight -def replace_jsonpickle_decode(json_str: str) -> Any: - """Replace jsonpickle.decode with own decoder. +class SkeletonDecoder: + """Replace jsonpickle.decode with our own decoder. - This function will decode the following from their encoded format: + This function will decode the following from jsonpickle's encoded format: `Node` objects from { @@ -119,9 +119,10 @@ def replace_jsonpickle_decode(json_str: str) -> Any: to the object with the same reconstruction id (from top to bottom). """ - def decode_id( - id: int, objects: List[Union[Node, EdgeType]] - ) -> Union[Node, EdgeType]: + def __init__(self): + self.decoded_objects: List[Union[Node, EdgeType]] = [] + + def _decode_id(self, id: int) -> Union[Node, EdgeType]: """Decode the object with the given `py/id` value of `id`. Args: @@ -131,9 +132,10 @@ def decode_id( Returns: The object with the given `py/id` value. """ - return objects[id - 1] + return self.decoded_objects[id - 1] - def decode_state(state: dict) -> Node: + @staticmethod + def _decode_state(state: dict) -> Node: """Reconstruct the `Node` object from 'py/state' key in the serialized nx_graph. We support states in either dictionary or tuple format: @@ -157,7 +159,8 @@ def decode_state(state: dict) -> Node: return Node(**state) - def decode_object_dict(object_dict) -> Node: + @staticmethod + def _decode_object_dict(object_dict) -> Node: """Decode dict containing `py/object` key in the serialized nx_graph. Args: @@ -175,17 +178,16 @@ def decode_object_dict(object_dict) -> Node: if object_dict["py/object"] != "sleap.skeleton.Node": raise ValueError("Only 'sleap.skeleton.Node' objects are supported.") - node: Node = decode_state(object_dict["py/state"]) + node: Node = SkeletonDecoder._decode_state(state=object_dict["py/state"]) return node - def decode_node( - encoded_node: dict, decoded_objects: List[Union[Node, EdgeType]] - ) -> Tuple[Node, List[Union[Node, EdgeType]]]: + def _decode_node(self, encoded_node: dict) -> Node: """Decode an item believed to be an encoded `Node` object. + Also updates the list of decoded objects. + Args: encoded_node: The encoded node to decode. - decoded_objects: The list of decoded objects so far. Returns: The decoded node and the updated list of decoded objects. @@ -193,18 +195,16 @@ def decode_node( if isinstance(encoded_node, int): # Using index mapping to replace the object (load from Labels) - return encoded_node, decoded_objects + return encoded_node elif "py/object" in encoded_node: - decoded_node: Node = decode_object_dict(encoded_node) - decoded_objects.append(decoded_node) + decoded_node: Node = SkeletonDecoder._decode_object_dict(encoded_node) + self.decoded_objects.append(decoded_node) elif "py/id" in encoded_node: - decoded_node: Node = decode_id(encoded_node["py/id"], decoded_objects) + decoded_node: Node = self._decode_id(encoded_node["py/id"]) - return decoded_node, decoded_objects + return decoded_node - def decode_nodes( - encoded_nodes: List[dict], decoded_objects: List[Union[Node, EdgeType]] - ) -> Tuple[List[Dict[str, Node]], List[Union[Node, EdgeType]]]: + def _decode_nodes(self, encoded_nodes: List[dict]) -> List[Dict[str, Node]]: """Decode the 'nodes' key in the serialized nx_graph. The encoded_nodes is a list of dictionary of two types: @@ -213,23 +213,20 @@ def decode_nodes( Args: encoded_nodes: The list of encoded nodes to decode. - decoded_objects: The list of decoded objects so far. Returns: - The decoded nodes and the updated list of decoded objects. + The decoded nodes. """ decoded_nodes: List[Dict[str, Node]] = [] for e_node_dict in encoded_nodes: e_node = e_node_dict["id"] - d_node, decoded_objects = decode_node(e_node, decoded_objects) + d_node = self._decode_node(e_node) decoded_nodes.append({"id": d_node}) - return decoded_nodes, decoded_objects + return decoded_nodes - def decode_reduce_dict( - reduce_dict: Dict[str, List[dict]], decoded_objects: List[Union[Node, EdgeType]] - ) -> EdgeType: + def _decode_reduce_dict(self, reduce_dict: Dict[str, List[dict]]) -> EdgeType: """Decode the 'reduce' key in the serialized nx_graph. The reduce_dict is a dictionary in the following format: @@ -242,7 +239,6 @@ def decode_reduce_dict( Args: reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...} - decoded_objects: The list of decoded objects so far. Returns: The decoded `EdgeType` object. @@ -272,38 +268,30 @@ def decode_reduce_dict( ) edge = EdgeType(edge_type) - decoded_objects.append(edge) + self.decoded_objects.append(edge) - return edge, decoded_objects + return edge - def decode_edge_type( - encoded_edge_type: dict, decoded_objects: List[Union[Node, EdgeType]] - ) -> Tuple[EdgeType, List[Union[Node, EdgeType]]]: + def _decode_edge_type(self, encoded_edge_type: dict) -> EdgeType: """Decode the 'type' key in the serialized nx_graph. Args: encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key. - decoded_objects: The list of decoded objects so far. Returns: - A tuple including the decoded `EdgeType` object and the updated list of - decoded objects. + The decoded `EdgeType` object. """ if "py/reduce" in encoded_edge_type: - edge_type, decoded_objects = decode_reduce_dict( - encoded_edge_type, decoded_objects=decoded_objects - ) + edge_type = self._decode_reduce_dict(encoded_edge_type) else: # Expect a "py/id" instead of "py/reduce" - edge_type = decode_id(encoded_edge_type["py/id"], decoded_objects) - return edge_type, decoded_objects - - def decode_links( - links: List[dict], decoded_objects: List[Union[Node, EdgeType]] - ) -> Tuple[ - List[Dict[str, Union[int, Node, EdgeType]]], List[Union[Node, EdgeType]] - ]: + edge_type = self._decode_id(encoded_edge_type["py/id"]) + return edge_type + + def _decode_links( + self, links: List[dict] + ) -> List[Dict[str, Union[int, Node, EdgeType]]]: """Decode the 'links' key in the serialized nx_graph. The links are the edges in the graph and will have the following keys: @@ -314,51 +302,51 @@ def decode_links( Args: encoded_links: The list of encoded links to decode. - decoded_objects: The list of decoded objects so far. """ for link in links: for key, value in link.items(): if key == "source": - link[key], decoded_objects = decode_node(value, decoded_objects) + link[key] = self._decode_node(value) elif key == "target": - link[key], decoded_objects = decode_node(value, decoded_objects) + link[key] = self._decode_node(value) elif key == "type": - link[key], decoded_objects = decode_edge_type( - value, decoded_objects - ) + link[key] = self._decode_edge_type(value) - return links, decoded_objects + return links - dicts = json.loads(json_str) + def _decode(self, json_str: str): + dicts = json.loads(json_str) - # Enforce same format across template and non-template skeletons - if "nx_graph" not in dicts: - # Non-template skeletons use the dicts as the "nx_graph" - dicts = {"nx_graph": dicts} + # Enforce same format across template and non-template skeletons + if "nx_graph" not in dicts: + # Non-template skeletons use the dicts as the "nx_graph" + dicts = {"nx_graph": dicts} - # Decode the graph - nx_graph = dicts["nx_graph"] + # Decode the graph + nx_graph = dicts["nx_graph"] - decoded_objects = [] - for key, value in nx_graph.items(): - if key == "nodes": - nx_graph[key], decoded_objects = decode_nodes( - value, decoded_objects=decoded_objects - ) - elif key == "links": - nx_graph[key], decoded_objects = decode_links( - value, decoded_objects=decoded_objects + self.decoded_objects = [] # Reset the decoded objects incase reusing decoder + for key, value in nx_graph.items(): + if key == "nodes": + nx_graph[key] = self._decode_nodes(value) + elif key == "links": + nx_graph[key] = self._decode_links(value) + + # Decode the preview image (if it exists) + preview_image = dicts.get("preview_image", None) + if preview_image is not None: + dicts["preview_image"] = decode_preview_image( + preview_image["py/b64"], return_bytes=True ) - # Decode the preview image (if it exists) - preview_image = dicts.get("preview_image", None) - if preview_image is not None: - dicts["preview_image"] = decode_preview_image( - preview_image["py/b64"], return_bytes=True - ) + return dicts - return dicts + @classmethod + def decode(cls, json_str: str) -> Dict: + """Decode the given json string into a dictionary.""" + decoder = cls() + return decoder._decode(json_str) class Skeleton: @@ -1347,7 +1335,7 @@ def from_json( Returns: An instance of the `Skeleton` object decoded from the JSON. """ - dicts: dict = replace_jsonpickle_decode(json_str) + dicts: dict = SkeletonDecoder.decode(json_str) nx_graph = dicts.get("nx_graph", dicts) graph = json_graph.node_link_graph(nx_graph) @@ -1557,8 +1545,3 @@ def __hash__(self): cattr.register_unstructure_hook(Skeleton, lambda skeleton: Skeleton.to_dict(skeleton)) cattr.register_structure_hook(Skeleton, lambda dicts, cls: Skeleton.from_dict(dicts)) - -if __name__ == "__main__": - ds = "sleap/skeletons/bees.json" - sk = Skeleton.load_json(ds) - print("debug hook")