-
Notifications
You must be signed in to change notification settings - Fork 2
/
annotate.py
executable file
·176 lines (158 loc) · 5.81 KB
/
annotate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
import mmcv
def color_val_matplotlib(color):
    """Convert a BGR color specification into a matplotlib RGB tuple.

    The input is resolved through ``mmcv.color_val``; the resulting channels
    are then reversed and each one is divided by 255.

    Args:
        color (:obj:`Color`/str/tuple/int/ndarray): Color inputs.

    Returns:
        tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
    """
    bgr = mmcv.color_val(color)
    # Reverse the channel order (BGR -> RGB) and scale each channel by 1/255.
    return tuple(channel / 255 for channel in bgr[::-1])
def annotate_image(im,
                   result,
                   CLASSES,
                   score_thr,
                   thickness=1,
                   font_size=13,
                   mask_color=None,
                   bbox_color=(72, 101, 241),
                   text_color=(72, 101, 241),
                   show=False,
                   out_file=None,
                   wait_time=1):
    """Draw detection boxes, labels and optional masks onto an image.

    The image is rendered through a matplotlib figure sized to the image,
    rasterized back into an array and returned. The figure created here is
    always closed before returning, including when an error is raised.

    Args:
        im: Image input forwarded to ``mmcv.imread``.
        result: Either ``bbox_result`` alone, or a tuple
            ``(bbox_result, segm_result)``. ``bbox_result`` is a sequence of
            per-class arrays that are stacked row-wise; the position of an
            array in the sequence is used as its class label. Stacked boxes
            must have 4 or 5 columns, the 5th being the score. When
            ``segm_result`` is itself a tuple, only its first item is used.
        CLASSES: Sequence indexed by label to obtain the text for a box, or
            ``None`` to fall back to ``'class <label>'``.
        score_thr (float): When greater than 0, only boxes whose score is
            strictly greater than this value are kept; boxes must then have
            5 columns.
        thickness (int): Line width of the box outlines.
        font_size (int): Font size of the label text.
        mask_color: Color (as accepted by ``mmcv.color_val``) used for every
            mask. When ``None``, one color per class index is drawn from a
            fixed seed, and the global NumPy random state is restored
            afterwards.
        bbox_color: Color of the box outlines, converted with
            ``color_val_matplotlib``.
        text_color: Color of the label text, converted with
            ``color_val_matplotlib``.
        show (bool): Whether to display the figure.
        out_file (str | None): When not ``None``, the rendered image is
            written there with ``mmcv.imwrite``.
        wait_time (float): Used only when ``show`` is true. ``0`` makes a
            blocking ``plt.show()``; any other value shows without blocking
            and pauses for that many seconds.

    Returns:
        ndarray: The rendered ``uint8`` image of shape ``(height, width, 3)``
        after conversion with ``mmcv.rgb2bgr``.

    Raises:
        AssertionError: If the stacked boxes/labels have unexpected
            dimensions, or ``score_thr > 0`` but boxes carry no score column.
    """
    # ---- unpack the detector output -------------------------------------
    if isinstance(result, tuple):
        bbox_result, segm_result = result
        if isinstance(segm_result, tuple):
            segm_result = segm_result[0]  # ms rcnn
    else:
        bbox_result, segm_result = result, None

    bboxes = np.vstack(bbox_result)
    # One label per box: the index of the per-class array it came from.
    labels = np.concatenate([
        np.full(bbox.shape[0], i, dtype=np.int32)
        for i, bbox in enumerate(bbox_result)
    ])

    # Collect segmentation masks into a single array aligned with `bboxes`.
    segms = None
    if segm_result is not None and len(labels) > 0:  # non empty
        segms = mmcv.concat_list(segm_result)
        if isinstance(segms[0], torch.Tensor):
            segms = torch.stack(segms, dim=0).detach().cpu().numpy()
        else:
            segms = np.stack(segms, axis=0)

    assert bboxes.ndim == 2, \
        f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
    assert labels.ndim == 1, \
        f' labels ndim should be 1, but its ndim is {labels.ndim}.'
    assert bboxes.shape[0] == labels.shape[0], \
        'bboxes.shape[0] and labels.shape[0] should have the same length.'
    assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
        f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'

    # astype() yields a new array, so the mask blending below works on it.
    img = mmcv.imread(im).astype(np.uint8)

    # ---- drop low-score detections --------------------------------------
    if score_thr > 0:
        assert bboxes.shape[1] == 5
        scores = bboxes[:, -1]
        inds = scores > score_thr
        bboxes = bboxes[inds, :]
        labels = labels[inds]
        if segms is not None:
            segms = segms[inds, ...]

    # ---- choose one mask color per class index --------------------------
    mask_colors = []
    if labels.shape[0] > 0:
        num_colors = int(labels.max()) + 1
        if mask_color is None:
            # Get random state before set seed, and restore random state
            # later. Prevent loss of randomness.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            # random color
            np.random.seed(42)
            mask_colors = [
                np.random.randint(0, 256, (1, 3), dtype=np.uint8)
                for _ in range(num_colors)
            ]
            np.random.set_state(state)
        else:
            # specify color
            mask_colors = [
                np.array(mmcv.color_val(mask_color)[::-1], dtype=np.uint8)
            ] * num_colors

    bbox_color = color_val_matplotlib(bbox_color)
    text_color = color_val_matplotlib(text_color)

    img = mmcv.bgr2rgb(img)
    width, height = img.shape[1], img.shape[0]
    img = np.ascontiguousarray(img)

    EPS = 1e-2
    fig = plt.figure('', frameon=False)
    # Everything that touches the figure lives inside try/finally so the
    # figure is released even if drawing, rasterizing or writing fails.
    try:
        plt.title('')
        canvas = fig.canvas
        dpi = fig.get_dpi()
        # add a small EPS to avoid precision lost due to matplotlib's
        # truncation (https://github.com/matplotlib/matplotlib/issues/15363)
        fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)

        # remove white edges by set subplot margin
        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
        ax = plt.gca()
        ax.axis('off')

        polygons = []
        color = []
        for i, (bbox, label) in enumerate(zip(bboxes, labels)):
            bbox_int = bbox.astype(np.int32)
            # Four rectangle corners built from the first four box values.
            poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
                    [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
            np_poly = np.array(poly).reshape((4, 2))
            polygons.append(Polygon(np_poly))
            color.append(bbox_color)

            label_text = CLASSES[
                label] if CLASSES is not None else f'class {label}'
            if len(bbox) > 4:
                # Append the score (last column) with two decimals.
                label_text += f'|{bbox[-1]:.02f}'
            ax.text(
                bbox_int[0],
                # Place the label above the box, clamped to the top edge.
                max(0, bbox_int[1] - int(font_size * 1.7)),  # bbox pad+1 = 1.7
                f'{label_text}',
                bbox={
                    'facecolor': 'black',
                    'alpha': 0.8,
                    'pad': 0.7,
                    'edgecolor': 'none'
                },
                color=text_color,
                fontsize=font_size,
                verticalalignment='top',
                horizontalalignment='left')

            if segms is not None:
                # Blend the class color 50/50 into the masked pixels.
                color_mask = mask_colors[labels[i]]
                mask = segms[i].astype(bool)
                img[mask] = img[mask] * 0.5 + color_mask * 0.5

        plt.imshow(img)

        p = PatchCollection(
            polygons, facecolor='none', edgecolors=color,
            linewidths=thickness)
        ax.add_collection(p)

        # Rasterize the figure and keep only the RGB channels.
        stream, _ = canvas.print_to_buffer()
        buffer = np.frombuffer(stream, dtype='uint8')
        img_rgba = buffer.reshape(height, width, 4)
        rgb, _ = np.split(img_rgba, [3], axis=2)
        img = rgb.astype('uint8')
        img = mmcv.rgb2bgr(img)

        if show:
            # We do not use cv2 for display because in some cases, opencv
            # will conflict with Qt, it will output a warning: Current thread
            # is not the object's thread. You can refer to
            # https://github.com/opencv/opencv-python/issues/46 for details
            if wait_time == 0:
                plt.show()
            else:
                plt.show(block=False)
                plt.pause(wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)
    finally:
        # Close this specific figure: a bare plt.close() targets whichever
        # figure is currently active, which may no longer be this one.
        plt.close(fig)

    return img