-
Notifications
You must be signed in to change notification settings - Fork 424
/
superpoint_pytorch.py
166 lines (143 loc) · 5.67 KB
/
superpoint_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""PyTorch implementation of the SuperPoint model,
derived from the TensorFlow re-implementation (2018).
Authors: Rémi Pautrat, Paul-Edouard Sarlin
"""
import copy
from collections import OrderedDict
from types import SimpleNamespace

import torch
import torch.nn as nn
def sample_descriptors(keypoints, descriptors, s: int = 8):
    """Bilinearly sample unit-norm descriptors at keypoint positions.

    Args:
        keypoints: (x, y) pixel coordinates in the full-resolution image;
            presumably shaped (b, n, 2) — it is viewed as (b, 1, -1, 2).
        descriptors: dense descriptor map of shape (b, c, h, w).
        s: resolution ratio between the image and the descriptor map, i.e.
            the image is treated as (h * s, w * s) pixels.

    Returns:
        Tensor of shape (b, c, n) holding one L2-normalized descriptor
        (along dim 1) per keypoint.
    """
    batch, dim, height, width = descriptors.shape
    image_size = keypoints.new_tensor([width, height]) * s
    # Shift to pixel centres, scale to [0, 1] of the full image, then map to
    # grid_sample's [-1, 1] range (align_corners=False: +/-1 are the outer
    # edges of the map).
    grid = 2 * ((keypoints + 0.5) / image_size) - 1
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(batch, 1, -1, 2), mode="bilinear", align_corners=False
    )
    # grid_sample yields (b, c, 1, n); flatten the spatial part.
    flat = sampled.reshape(batch, dim, -1)
    return torch.nn.functional.normalize(flat, p=2, dim=1)
def batched_nms(scores, nms_radius: int):
    """Non-maximum suppression on dense score maps.

    A score survives if it equals the maximum of its (2 * nms_radius + 1)
    square window. Two further rounds then look for new local maxima among
    the pixels that are not inside the window of an already kept maximum.
    Every non-kept entry is set to zero.

    Args:
        scores: score tensor accepted by ``max_pool2d`` (e.g. (b, H, W)).
        nms_radius: non-negative window radius; an AssertionError is raised
            otherwise.

    Returns:
        Tensor shaped like ``scores`` with suppressed entries zeroed.
    """
    assert nms_radius >= 0

    window = 2 * nms_radius + 1

    def dilate(t):
        # Sliding-window maximum, same spatial size as the input.
        return torch.nn.functional.max_pool2d(
            t, kernel_size=window, stride=1, padding=nms_radius
        )

    background = torch.zeros_like(scores)
    keep = scores == dilate(scores)
    for _ in range(2):
        covered = dilate(keep.float()) > 0
        remaining = torch.where(covered, background, scores)
        fresh_peaks = remaining == dilate(remaining)
        keep = keep | (fresh_peaks & (~covered))
    return torch.where(keep, scores, background)
def select_top_k_keypoints(keypoints, scores, k):
    """Keep at most ``k`` keypoints, preferring the highest scores.

    When there are more than ``k`` keypoints, the survivors are returned in
    descending score order. Otherwise both inputs are returned untouched
    (same objects, original order).
    """
    if k < len(keypoints):
        top_scores, order = torch.topk(scores, k, dim=0, sorted=True)
        return keypoints[order], top_scores
    return keypoints, scores
class VGGBlock(nn.Sequential):
    """Conv2d -> ReLU (or identity) -> BatchNorm2d, applied in that order.

    The submodules are registered as ``conv``, ``activation`` and ``bn``.
    These names appear in the ``state_dict`` keys.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        kernel_size: convolution kernel size. The stride is 1 and the padding
            is ``(kernel_size - 1) // 2``, so the spatial size is preserved
            for odd kernel sizes.
        relu: if False, the activation is ``nn.Identity``.
    """

    def __init__(self, c_in, c_out, kernel_size, relu=True):
        stages = OrderedDict()
        stages["conv"] = nn.Conv2d(
            c_in,
            c_out,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        stages["activation"] = nn.ReLU(inplace=True) if relu else nn.Identity()
        stages["bn"] = nn.BatchNorm2d(c_out, eps=0.001)
        super().__init__(stages)
class SuperPoint(nn.Module):
    """SuperPoint keypoint detector and descriptor.

    Configuration (defaults in ``default_conf``, overridable through keyword
    arguments to the constructor):

    - ``nms_radius``: radius passed to ``batched_nms`` on the score map.
    - ``max_num_keypoints``: if not None, only this many keypoints are kept
      per image, via ``select_top_k_keypoints``.
    - ``detection_threshold``: pixels whose score is strictly greater than
      this value become keypoints.
    - ``remove_borders``: number of rows/columns at each image edge whose
      scores are forced to -1; a falsy value disables this.
    - ``descriptor_dim``: number of output channels of the descriptor head.
    - ``channels``: all entries but the last are the encoder stage widths;
      the last one is the hidden width of the detector and descriptor heads.
    """

    default_conf = {
        "nms_radius": 4,
        "max_num_keypoints": None,
        "detection_threshold": 0.005,
        "remove_borders": 4,
        "descriptor_dim": 256,
        "channels": [64, 64, 128, 128, 256],
    }

    def __init__(self, **conf):
        super().__init__()
        # Deep-copy the class-level defaults. A shallow merge would make
        # ``self.conf.channels`` the very list stored in ``default_conf``, so
        # mutating it on one instance would change the defaults of all later
        # instances. User-supplied values are used as given.
        conf = {**copy.deepcopy(self.default_conf), **conf}
        self.conf = SimpleNamespace(**conf)
        # Every encoder stage except the last ends with a 2x2 max-pool, so the
        # total downsampling is 2 ** (number of encoder stages - 1).
        self.stride = 2 ** (len(self.conf.channels) - 2)

        channels = [1, *self.conf.channels[:-1]]  # 1 = grayscale input channel
        backbone = []
        for i, c in enumerate(channels[1:], 1):
            layers = [VGGBlock(channels[i - 1], c, 3), VGGBlock(c, c, 3)]
            if i < len(channels) - 1:  # no pooling after the last stage
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            backbone.append(nn.Sequential(*layers))
        self.backbone = nn.Sequential(*backbone)

        c = self.conf.channels[-1]
        # stride**2 channels (one per pixel of a stride x stride cell) plus
        # one extra channel that is dropped after the softmax.
        self.detector = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.stride**2 + 1, 1, relu=False),
        )
        self.descriptor = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.conf.descriptor_dim, 1, relu=False),
        )

    def _to_grayscale(self, image):
        """Collapse a 3-channel image to one channel with fixed RGB weights.

        Images with any other channel count are returned unchanged.
        """
        if image.shape[1] == 3:  # RGB to gray
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        return image

    def _decode_scores(self, logits):
        """Turn detector logits into a dense score map.

        Applies a softmax over channels and drops the last channel. The
        remaining ``stride**2`` channels of each cell are then unfolded into a
        ``stride x stride`` patch. Returns shape ``(b, h * stride, w * stride)``,
        where ``(h, w)`` is the spatial size of ``logits``.
        """
        scores = torch.nn.functional.softmax(logits, 1)[:, :-1]
        b, _, h, w = scores.shape
        s = self.stride
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, s, s)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * s, w * s)
        return scores

    def _suppress_borders(self, scores):
        """Set scores within ``remove_borders`` rows/columns of each edge to -1.

        Modifies ``scores`` in place and returns it. It is a no-op when the
        option is falsy.
        """
        pad = self.conf.remove_borders
        if pad:
            scores[:, :pad] = -1
            scores[:, :, :pad] = -1
            scores[:, -pad:] = -1
            scores[:, :, -pad:] = -1
        return scores

    def _threshold_keypoints(self, scores):
        """Collect above-threshold pixels of a (b, H, W) score map per image.

        Returns two lists of length b: float (x, y) keypoint coordinates of
        shape (n_i, 2), and the matching scores of shape (n_i,).
        """
        b = scores.shape[0]
        threshold = self.conf.detection_threshold
        if b > 1:
            idxs = torch.where(scores > threshold)
            # mask[i] selects the detections whose batch index is i.
            mask = idxs[0] == torch.arange(b, device=scores.device)[:, None]
        else:  # Single-image shortcut: no per-image mask needed.
            scores = scores.squeeze(0)
            idxs = torch.where(scores > threshold)
        # Convert (row, col) = (i, j) indices to (x, y) coordinates.
        keypoints_all = torch.stack(idxs[-2:], dim=-1).flip(1).float()
        scores_all = scores[idxs]
        if b > 1:
            return (
                [keypoints_all[m] for m in mask],
                [scores_all[m] for m in mask],
            )
        # b is 0 or 1 here: list repetition yields [] or a single entry.
        return [keypoints_all] * b, [scores_all] * b

    def forward(self, data):
        """Detect keypoints and compute their descriptors.

        Args:
            data: dict whose ``"image"`` entry is an image batch, presumably
                shaped (b, 1 or 3, H, W). 3-channel inputs are converted to
                grayscale first.

        Returns:
            dict with one entry per image in each list:
            ``"keypoints"``: float (n_i, 2) tensors of (x, y) pixel coordinates;
            ``"keypoint_scores"``: (n_i,) tensors;
            ``"descriptors"``: (n_i, descriptor_dim) L2-normalized tensors.
        """
        image = self._to_grayscale(data["image"])
        features = self.backbone(image)
        descriptors_dense = torch.nn.functional.normalize(
            self.descriptor(features), p=2, dim=1
        )

        scores = self._decode_scores(self.detector(features))
        scores = batched_nms(scores, self.conf.nms_radius)
        # Discard keypoints near the image borders.
        scores = self._suppress_borders(scores)
        keypoints_per_image, scores_per_image = self._threshold_keypoints(scores)

        keypoints, keypoint_scores, descriptors = [], [], []
        for i, (k, s) in enumerate(zip(keypoints_per_image, scores_per_image)):
            if self.conf.max_num_keypoints is not None:
                k, s = select_top_k_keypoints(k, s, self.conf.max_num_keypoints)
            # sample_descriptors works on batches: add, then drop, a batch dim.
            d = sample_descriptors(k[None], descriptors_dense[i, None], self.stride)
            keypoints.append(k)
            keypoint_scores.append(s)
            descriptors.append(d.squeeze(0).transpose(0, 1))
        return {
            "keypoints": keypoints,
            "keypoint_scores": keypoint_scores,
            "descriptors": descriptors,
        }