diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index f4d22b829..e897930a2 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -74,6 +74,7 @@ from sleap.prefs import prefs from sleap.skeleton import Node from sleap.io.cameras import Camcorder +from sleap.io.cameras import InstanceGroup class LoadImageWorker(QtCore.QObject): @@ -272,6 +273,7 @@ def update_selection_state(a, b): self.state.connect("frame_idx", lambda idx: self.plot()) self.state.connect("frame_idx", lambda idx: self.seekbar.setValue(idx)) self.state.connect("instance", self.view.selectInstance) + self.state.connect("instance_group", self.view.selectInstance) self.state.connect("show instances", self.plot) self.state.connect("show labels", self.plot) @@ -960,18 +962,42 @@ def all_instances(self) -> List["QtInstance"]: scene_items = self.scene.items(Qt.SortOrder.AscendingOrder) return list(filter(lambda x: isinstance(x, QtInstance), scene_items)) - def selectInstance(self, select: Union[Instance, int]): - """ - Select a particular instance in view. + def selectInstance(self, select: Optional[Union[Instance, int, InstanceGroup]]): + """Select a particular instance in view. Args: - select: Either `Instance` or index of instance in view. + select: Either `None` or `Instance`, index, or `InstanceGroup` of instance + in view. Returns: None """ + + # Decide which function to use to determine if instance is selected + if isinstance(select, int): + + def determine_selected(idx: int, instance: QtInstance): + return idx == select + + elif isinstance(select, Instance): + + def determine_selected(idx: int, instance: QtInstance): + return instance.instance == select + + elif isinstance(select, InstanceGroup): + + def determine_selected(idx: int, instance: QtInstance): + return instance.instance in select.instances + + else: + + def determine_selected(idx: int, instance: QtInstance): + return False + + # Set selected state for each instance for idx, instance in enumerate(self.all_instances): - instance.selected = select == idx or select == instance.instance + instance.selected = determine_selected(idx, instance) + self.updatedSelection.emit() def getSelectionIndex(self) -> Optional[int]: