diff --git a/sam2/utils/visualization.py b/sam2/utils/visualization.py index 6c4c10e9..c17f9e1e 100644 --- a/sam2/utils/visualization.py +++ b/sam2/utils/visualization.py @@ -11,25 +11,41 @@ def show_masks( alpha: Optional[float] = 0.5, display_image: Optional[bool] = False, only_best: Optional[bool] = True, + autogenerated_mask: Optional[bool] = False, ) -> Image.Image: if scores is not None: # sort masks by their scores sorted_ind = np.argsort(scores)[::-1] masks = masks[sorted_ind] - # get mask dimensions - h, w = masks.shape[-2:] + if autogenerated_mask: + masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) + else: + # get mask dimensions + h, w = masks.shape[-2:] if display_image: output_image = Image.fromarray(image) else: # create a new blank image to superimpose masks - output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0)) + if autogenerated_mask: + output_image = Image.new( + mode="RGBA", + size=( + masks[0]["segmentation"].shape[0], + masks[0]["segmentation"].shape[1], + ), + color=(0, 0, 0), + ) + else: + output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0)) for i, mask in enumerate(masks): - if mask.ndim > 2: # type: ignore - mask = mask.squeeze() # type: ignore - + if not autogenerated_mask: + if mask.ndim > 2: # type: ignore + mask = mask.squeeze() # type: ignore + else: + mask = mask["segmentation"] # Generate a random color with specified alpha value color = np.concatenate( (np.random.randint(0, 256, size=3), [int(alpha * 255)]), axis=0