import numpy as np
import torch

'''
This code is sourced from:
Shen, Zhijie and Lin, Chunyu and Liao, Kang and Nie, Lang and Zheng, Zishuo and Zhao, Yao,
"PanoFormer: Panorama Transformer for Indoor 360° Depth Estimation", European Conference on Computer Vision, 2022, pp. 195-211.
https://github.com/zhijieshen-bjtu/PanoFormer
'''


def genSamplingPattern(h, w, kh, kw, stride=1):
    # Build the (lat, lon) sampling pattern for a kh x kw spherical kernel on an
    # h x w equirectangular grid and return it as a non-trainable float tensor.
    gridGenerator = GridGenerator(h, w, (kh, kw), stride)
    LonLatSamplingPattern = gridGenerator.createSamplingPattern()

    grid = LonLatSamplingPattern
    with torch.no_grad():
        grid = torch.FloatTensor(grid)
        grid.requires_grad = False

    return grid
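
# Usage note (a sketch, assuming the grid is consumed by torch.nn.functional.grid_sample
# or a similar bilinear sampler; the original file does not show the consumer): the pattern
# above stores (lat, lon) pixel indices of shape (H, W, Kh, Kw, 2). A typical consumer would
# reorder the last axis to (x, y) = (lon, lat), reshape to (1, H * Kh, W * Kw, 2), and
# rescale both axes to [-1, 1] before sampling.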


class GridGenerator:
    def __init__(self, height: int, width: int, kernel_size, stride=1):
        self.height = height
        self.width = width
        self.kernel_size = kernel_size  # (Kh, Kw)
        self.stride = stride  # sampling stride, applied along both H and W
    def createSamplingPattern(self):
        """
        :return: (H, W, Kh, Kw, (Lat, Lon)) sampling pattern
        """
        kerX, kerY = self.createKernel()  # (Kh, Kw)

        # create some values used in generating the lat/lon sampling pattern
        rho = np.sqrt(kerX ** 2 + kerY ** 2)
        Kh, Kw = self.kernel_size
        # when the value of rho at the kernel center is zero, some lat values explode to `nan`
        if Kh % 2 and Kw % 2:
            rho[Kh // 2][Kw // 2] = 1e-8

        nu = np.arctan(rho)
        cos_nu = np.cos(nu)
        sin_nu = np.sin(nu)

        stride_h, stride_w = self.stride, self.stride
        h_range = np.arange(0, self.height, stride_h)
        w_range = np.arange(0, self.width, stride_w)

        lat_range = ((h_range / self.height) - 0.5) * np.pi
        lon_range = ((w_range / self.width) - 0.5) * (2 * np.pi)

        # generate latitude sampling pattern
        lat = np.array([
            np.arcsin(cos_nu * np.sin(_lat) + kerY * sin_nu * np.cos(_lat) / rho) for _lat in lat_range
        ])  # (H, Kh, Kw)

        lat = np.array([lat for _ in lon_range])  # (W, H, Kh, Kw)
        lat = lat.transpose((1, 0, 2, 3))  # (H, W, Kh, Kw)

        # generate longitude sampling pattern
        lon = np.array([
            np.arctan(kerX * sin_nu / (rho * np.cos(_lat) * cos_nu - kerY * np.sin(_lat) * sin_nu)) for _lat in lat_range
        ])  # (H, Kh, Kw)

        lon = np.array([lon + _lon for _lon in lon_range])  # (W, H, Kh, Kw)
        lon = lon.transpose((1, 0, 2, 3))  # (H, W, Kh, Kw)

        # (radian) -> (index of pixel)
        lat = (lat / np.pi + 0.5) * self.height - 0.5
        lon = ((lon / (2 * np.pi) + 0.5) * self.width) % self.width - 0.5

        # (2, H, W, Kh, Kw) = ((lat, lon), H, W, Kh, Kw)
        LatLon = np.stack((lat, lon))
        # (H, W, Kh, Kw, 2) = (H, W, Kh, Kw, (lat, lon))
        LatLon = LatLon.transpose((1, 2, 3, 4, 0))

        # H, W, Kh, Kw, d = LatLon.shape
        # LatLon = LatLon.reshape((1, H, W, Kh, Kw, d))  # (1, H, W, Kh, Kw, 2)

        return LatLon
    def createKernel(self):
        """
        :return: (Ky, Kx) kernel pattern
        """
        Kh, Kw = self.kernel_size

        delta_lat = np.pi / (self.height // self.stride)
        delta_lon = 2 * np.pi / (self.width // self.stride)

        range_x = np.arange(-(Kw // 2), Kw // 2 + 1)
        if not Kw % 2:
            # even kernel width: drop the center offset so the kernel keeps Kw taps
            range_x = np.delete(range_x, Kw // 2)

        range_y = np.arange(-(Kh // 2), Kh // 2 + 1)
        if not Kh % 2:
            range_y = np.delete(range_y, Kh // 2)

        kerX = np.tan(range_x * delta_lon)
        kerY = np.tan(range_y * delta_lat) / np.cos(range_y * delta_lon)

        return np.meshgrid(kerX, kerY)  # (Kh, Kw)
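

if __name__ == "__main__":
    # Minimal sanity-check sketch (illustrative, not taken from the referenced PanoFormer
    # repository): build a 3x3 sampling pattern for a 64x128 equirectangular feature map
    # and confirm the (H, W, Kh, Kw, 2) layout returned by createSamplingPattern.
    grid = genSamplingPattern(h=64, w=128, kh=3, kw=3, stride=1)
    print(grid.shape)          # torch.Size([64, 128, 3, 3, 2])
    print(grid.dtype)          # torch.float32
    print(grid.requires_grad)  # False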