# Code from https://github.com/gmberton/geo_warp
import os
import random

import kornia
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from shapely.geometry import Polygon


def open_image(path):
    return Image.open(path).convert("RGB")


# DATASET FOR WARPING
def get_random_trapezoid(k=1):
    """Get the points (with shape [4, 2]) of a random trapezoid with two vertical sides.
    With k=0, the trapezoid is a rectangle with points:
        [[-1., -1.], [ 1., -1.], [ 1., 1.], [-1., 1.]]
    Parameters
    ----------
    k : float, with 0 <= k <= 1, indicates the "randomness" of the trapezoid.
        The higher k, the more "random" the trapezoid.
    """
    assert 0 <= k <= 1

    def rand(k):
        return 1 - (random.random() * k)

    left = -rand(k)
    right = rand(k)
    trap_points = np.empty(shape=(4, 2))
    trap_points[0] = (left, -rand(k))
    trap_points[1] = (right, -rand(k))
    trap_points[2] = (right, rand(k))
    trap_points[3] = (left, rand(k))
    return trap_points
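
# Illustrative example (not from the original repo): with k=0 every call
# returns the unit square, while larger k shrinks each corner coordinate
# by up to k, making the trapezoid more irregular:
#   >>> get_random_trapezoid(k=0).tolist()
#   [[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]]
#   >>> get_random_trapezoid(k=0.5).shape
#   (4, 2)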


def compute_warping(model, tensor_img_1, tensor_img_2, weights=None):
    """Compute the pairwise warping of two (batches of) images, using a given model.
    Given that the operations in the model are not commutative (i.e. the order of
    the tensors matters), this function computes the mean by passing the tensor
    images in both orders.
    Parameters
    ----------
    model : network.Network, used to compute the homography.
    tensor_img_1 : torch.Tensor, the query images, with shape [B, 3, H, W].
    tensor_img_2 : torch.Tensor, the gallery images, with shape [B, 3, H, W].
    weights : torch.Tensor, random weights to avoid numerical instability;
        usually they're not needed.
    Returns
    -------
    warped_tensor_img_1 : torch.Tensor, the warped query images, with shape [B, 3, H, W].
    warped_tensor_img_2 : torch.Tensor, the warped gallery images, with shape [B, 3, H, W].
    mean_pred_points_1 : torch.Tensor, the predicted points used to compute the homography
        on the query images, with shape [B, 4, 2].
    mean_pred_points_2 : torch.Tensor, the predicted points used to compute the homography
        on the gallery images, with shape [B, 4, 2].
    """
    # Get the predictions for both orders of the pair
    pred_points_1to2, pred_points_2to1 = model("similarity_and_regression",
                                               [tensor_img_1, tensor_img_2])
    # Average them
    mean_pred_points_1 = (pred_points_1to2[:, :4] + pred_points_2to1[:, 4:]) / 2
    mean_pred_points_2 = (pred_points_1to2[:, 4:] + pred_points_2to1[:, :4]) / 2
    # Apply the homography
    warped_tensor_img_1, _ = warp_images(tensor_img_1, mean_pred_points_1, weights)
    warped_tensor_img_2, _ = warp_images(tensor_img_2, mean_pred_points_2, weights)
    return warped_tensor_img_1, warped_tensor_img_2, mean_pred_points_1, mean_pred_points_2
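
# Usage sketch (shapes are assumptions inferred from the slicing above): each
# prediction has shape [B, 8, 2], i.e. 4 points per image of the pair, and
# averaging the two call orders symmetrizes the regression:
#   wq, wg, pts_q, pts_g = compute_warping(model, queries, galleries)
#   # wq, wg: [B, 3, H, W]; pts_q, pts_g: [B, 4, 2]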


def warp_images(tensor_img, warping_points, weights=None):
    """Apply the homography to a batch of images, using the points specified in warping_points.
    Parameters
    ----------
    tensor_img : torch.Tensor, the images, with shape [B, 3, H, W].
    warping_points : torch.Tensor, the points used to compute the homography, with shape [B, 4, 2].
    weights : torch.Tensor, random weights to avoid numerical instability; usually they're not needed.
    Returns
    -------
    warped_images : torch.Tensor, the warped images, with shape [B, 3, H, W].
    theta : torch.Tensor, the homography matrix, usually not needed, with shape [B, 3, 3].
    """
    B, C, H, W = tensor_img.shape
    assert warping_points.shape == torch.Size([B, 4, 2])
    # The source points are the corners of the full image (a k=0 "trapezoid", i.e. a rectangle)
    rectangle_points = torch.repeat_interleave(torch.tensor(get_random_trapezoid(k=0)).unsqueeze(0), B, 0)
    rectangle_points = rectangle_points.to(tensor_img.device)
    # NB: for older versions of kornia use kornia.find_homography_dlt
    theta = kornia.geometry.homography.find_homography_dlt(rectangle_points.float(), warping_points.float(), weights)
    # NB: for older versions of kornia use kornia.homography_warp
    warped_images = kornia.geometry.homography_warp(tensor_img, theta, dsize=(H, W))
    return warped_images, theta
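
# Shape-only sketch (illustrative, not part of the training pipeline):
#   imgs = torch.rand(2, 3, 512, 512)
#   corners = torch.tensor(get_random_trapezoid(k=0.3)).unsqueeze(0).repeat(2, 1, 1)
#   warped, theta = warp_images(imgs, corners)
#   # warped: [2, 3, 512, 512]; theta: [2, 3, 3]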


def get_random_homographic_pair(source_img, k, is_debugging=False):
    """Given a source image, return a pair of warped images generated in a self-supervised
    fashion, together with the points used to generate the projections (homography).
    Parameters
    ----------
    source_img : torch.Tensor, with shape [3, H, W].
    k : float, the k parameter indicates how "difficult" the generated pair is,
        i.e. it's an upper bound on how strong the warping can be.
    is_debugging : bool, if True, return extra information.
    """
    # Compute two random trapezoids and their intersection
    trap_points_1 = get_random_trapezoid(k)
    trap_points_2 = get_random_trapezoid(k)
    t1 = torch.tensor(trap_points_1)
    t2 = torch.tensor(trap_points_2)
    points_trapezoids = torch.cat((t1.unsqueeze(0), t2.unsqueeze(0)))
    trap_1 = Polygon(trap_points_1)
    trap_2 = Polygon(trap_points_2)
    intersection = trap_2.intersection(trap_1)
    # Some operations to get the intersection points as a torch.Tensor of shape [4, 2]
    list_x, list_y = intersection.exterior.coords.xy
    a3, d3 = sorted(set((x, y) for x, y in zip(list_x, list_y) if x == min(list_x)))
    b3, c3 = sorted(set((x, y) for x, y in zip(list_x, list_y) if x == max(list_x)))
    intersection_points = torch.tensor([a3, b3, c3, d3]).type(torch.float)
    intersection_points = torch.repeat_interleave(intersection_points.unsqueeze(0), 2, 0)
    image_repeated_twice = torch.repeat_interleave(source_img.unsqueeze(0), 2, 0)
    warped_images, theta = warp_images(image_repeated_twice, points_trapezoids)
    theta = torch.inverse(theta)
    # Compute the positions of intersection_points projected onto the warped images
    xs = [(theta[:, 0, 0] * intersection_points[:, p, 0] + theta[:, 0, 1] * intersection_points[:, p, 1] + theta[:, 0, 2]) /
          (theta[:, 2, 0] * intersection_points[:, p, 0] + theta[:, 2, 1] * intersection_points[:, p, 1] + theta[:, 2, 2])
          for p in range(4)]
    ys = [(theta[:, 1, 0] * intersection_points[:, p, 0] + theta[:, 1, 1] * intersection_points[:, p, 1] + theta[:, 1, 2]) /
          (theta[:, 2, 0] * intersection_points[:, p, 0] + theta[:, 2, 1] * intersection_points[:, p, 1] + theta[:, 2, 2])
          for p in range(4)]
    # Refactor the projected intersection points as a torch.Tensor with shape [2, 4, 2]
    warped_intersection_points = torch.cat((torch.stack(xs).T.reshape(2, 4, 1), torch.stack(ys).T.reshape(2, 4, 1)), 2)
    if is_debugging:
        warped_images_intersection, inverse_theta = warp_images(warped_images, warped_intersection_points)
        return ((source_img, *warped_images, *warped_images_intersection), (theta, inverse_theta),
                (*points_trapezoids, *intersection_points, *warped_intersection_points))
    else:
        return warped_images[0], warped_images[1], warped_intersection_points[0], warped_intersection_points[1]
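
# In short: two random warps of the same image are generated, and the corners
# of the trapezoids' intersection (the region visible in both warps) are
# projected into each warped image; those projected points are the
# self-supervision target. Illustrative call:
#   w1, w2, pts1, pts2 = get_random_homographic_pair(torch.rand(3, 512, 512), k=0.1)
#   # w1, w2: [3, 512, 512]; pts1, pts2: [4, 2] matching points in each warp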


class HomographyDataset(torch.utils.data.Dataset):
    def __init__(self, args, dataset_folder, M=10, N=5, current_group=0, min_images_per_class=10, k=0.1,
                 is_debugging=False):
        super().__init__()
        self.M = M
        self.N = N
        self.current_group = current_group
        self.dataset_folder = dataset_folder
        self.augmentation_device = args.augmentation_device
        self.k = k
        self.is_debugging = is_debugging
        # dataset_name should be either "processed", "small" or "raw", if you're using SF-XL
        dataset_name = os.path.basename(args.dataset_folder)
        filename = f"cache/{dataset_name}_M{M}_N{N}_mipc{min_images_per_class}.torch"
        classes_per_group, self.images_per_class = torch.load(filename)
        self.classes_ids = classes_per_group[current_group]
        if self.augmentation_device == "cpu":
            self.transform = T.Compose([
                T.ColorJitter(brightness=args.brightness,
                              contrast=args.contrast,
                              saturation=args.saturation,
                              hue=args.hue),
                T.RandomResizedCrop([512, 512], scale=[1 - args.random_resized_crop, 1]),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.base_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __getitem__(self, class_num):
        # This function takes as input the class_num instead of the index of
        # the image. This way each class is equally represented during warping.
        class_id = self.classes_ids[class_num]
        image_path = random.choice(self.images_per_class[class_id])
        pil_image = open_image(image_path)
        tensor_image = self.base_transform(pil_image)
        return get_random_homographic_pair(tensor_image, self.k, is_debugging=self.is_debugging)

    def __len__(self):
        """Return the number of homography classes within this group."""
        return len(self.classes_ids)
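

if __name__ == "__main__":
    # Minimal smoke test for the geometric utilities above; an illustrative
    # addition that assumes only torch and kornia are installed and needs no
    # dataset files on disk.
    dummy_img = torch.rand(3, 512, 512)
    warped_1, warped_2, points_1, points_2 = get_random_homographic_pair(dummy_img, k=0.2)
    print(f"warped images: {tuple(warped_1.shape)}, {tuple(warped_2.shape)}")
    print(f"matching points: {tuple(points_1.shape)}, {tuple(points_2.shape)}")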