Skip to content

Commit

Permalink
feat: add flag to visualize autogenerated masks
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Aug 4, 2024
1 parent 31f4c41 commit 3f262d9
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions sam2/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,41 @@ def show_masks(
alpha: Optional[float] = 0.5,
display_image: Optional[bool] = False,
only_best: Optional[bool] = True,
autogenerated_mask: Optional[bool] = False,
) -> Image.Image:
if scores is not None:
# sort masks by their scores
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]

# get mask dimensions
h, w = masks.shape[-2:]
if autogenerated_mask:
masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
else:
# get mask dimensions
h, w = masks.shape[-2:]

if display_image:
output_image = Image.fromarray(image)
else:
# create a new blank image to superimpose masks
output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0))
if autogenerated_mask:
output_image = Image.new(
mode="RGBA",
size=(
masks[0]["segmentation"].shape[0],
masks[0]["segmentation"].shape[1],
),
color=(0, 0, 0),
)
else:
output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0))

for i, mask in enumerate(masks):
if mask.ndim > 2: # type: ignore
mask = mask.squeeze() # type: ignore

if not autogenerated_mask:
if mask.ndim > 2: # type: ignore
mask = mask.squeeze() # type: ignore
else:
mask = mask["segmentation"]
# Generate a random color with specified alpha value
color = np.concatenate(
(np.random.randint(0, 256, size=3), [int(alpha * 255)]), axis=0
Expand Down

0 comments on commit 3f262d9

Please sign in to comment.