forked from Vectorized/Python-KD-Tree
-
Notifications
You must be signed in to change notification settings - Fork 0
/
kd_tree.py
151 lines (128 loc) · 4.82 KB
/
kd_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
class KDTree(object):
    """
    A super short KD-Tree for points...
    so concise that you can copypasta into your homework
    without arousing suspicion.

    This implementation only supports Euclidean distance.

    The points can be any array-like type, e.g:
        lists, tuples, numpy arrays.

    Usage:
    1. Make the KD-Tree:
        `kd_tree = KDTree(points, dim)`
    2. You can then use `get_knn` for k nearest neighbors or
       `get_nearest` for the nearest neighbor

    `points` is an iterable of points: [[0, 1, 2], [12.3, 4.5, 2.3], ...]
    """

    def __init__(self, points, dim, dist_sq_func=None):
        """Makes the KD-Tree for fast lookup.

        Parameters
        ----------
        points : iterable<point>
            The points to index. The container is copied before the
            tree is built, so the caller's list is never reordered.
        dim : int
            The dimension of the points.
        dist_sq_func : function(point, point), optional
            A function that returns the squared Euclidean distance
            between the two points.
            If omitted, it uses the default implementation.
            Branch pruning uses the squared difference along the split
            axis, so this function must be consistent with squared
            Euclidean distance for k-NN results to be exact.
        """
        import heapq

        if dist_sq_func is None:
            dist_sq_func = lambda a, b: sum((x - b[i]) ** 2
                                            for i, x in enumerate(a))

        # Every node is a mutable list: [left_subtree, right_subtree, point].
        # An empty (sub)tree is represented by None.

        def make(points, i=0):
            """Build a balanced subtree from a private list of points,
            splitting on axis `i` at the median."""
            # `points` is always a private list here (the copy made below
            # or a slice of it), so sorting in place is safe.
            if len(points) > 1:
                points.sort(key=lambda x: x[i])
                i = (i + 1) % dim
                m = len(points) >> 1
                return [make(points[:m], i), make(points[m + 1:], i),
                        points[m]]
            if len(points) == 1:
                return [None, None, points[0]]
            return None

        def add_point(node, point, i=0):
            """Insert `point` as a new leaf below `node` (no rebalancing)."""
            if node is not None:
                dx = node[2][i] - point[i]
                # Points not greater than the split value go left,
                # matching how `make` partitions around the median.
                for j, c in ((0, dx >= 0), (1, dx < 0)):
                    if c and node[j] is None:
                        node[j] = [None, None, point]
                    elif c:
                        add_point(node[j], point, (i + 1) % dim)

        def get_knn(node, point, k, return_dist_sq, heap, i=0, tiebreaker=1):
            """Collect the k nearest points into `heap` (a max-heap via
            negated distances); the root call (tiebreaker == 1) returns
            the formatted result, nearest first."""
            if node is not None:
                dist_sq = dist_sq_func(point, node[2])
                dx = node[2][i] - point[i]
                # `tiebreaker` encodes the path from the root, so it is
                # unique per node and heap entries never compare points.
                if len(heap) < k:
                    heapq.heappush(heap, (-dist_sq, tiebreaker, node[2]))
                elif dist_sq < -heap[0][0]:
                    heapq.heappushpop(heap, (-dist_sq, tiebreaker, node[2]))
                i = (i + 1) % dim
                near = 0 if dx >= 0 else 1
                far = 1 - near
                get_knn(node[near], point, k, return_dist_sq,
                        heap, i, (tiebreaker << 1) | near)
                # Decide about the far side only after the near side has
                # been searched (the worst kept distance can only shrink).
                # The far side must be visited while the heap still has
                # fewer than k points, or when the splitting plane is
                # closer than the current k-th best distance.
                if len(heap) < k or dx * dx < -heap[0][0]:
                    get_knn(node[far], point, k, return_dist_sq,
                            heap, i, (tiebreaker << 1) | far)
            if tiebreaker == 1:
                return [(-h[0], h[2]) if return_dist_sq else h[2]
                        for h in sorted(heap)][::-1]

        def walk(node):
            """Yield every point in the subtree: left subtree, right
            subtree, then the node's own point."""
            if node is not None:
                for j in 0, 1:
                    for x in walk(node[j]):
                        yield x
                yield node[2]

        self._add_point = add_point
        self._get_knn = get_knn
        # Copy the input so the caller's container is left untouched and
        # any iterable (tuple, generator, numpy array rows) is accepted.
        self._root = make(list(points))
        self._walk = walk

    def __iter__(self):
        """Iterate over all stored points (left subtree, right subtree,
        then node)."""
        return self._walk(self._root)

    def add_point(self, point):
        """Adds a point to the kd-tree.

        The point is attached as a new leaf; the tree is not rebalanced.

        Parameters
        ----------
        point : array-like
            The point.
        """
        if self._root is None:
            self._root = [None, None, point]
        else:
            self._add_point(self._root, point)

    def get_knn(self, point, k, return_dist_sq=True):
        """Returns k nearest neighbors.

        Parameters
        ----------
        point : array-like
            The point.
        k: int
            The number of nearest neighbors. If the tree holds fewer
            than k points, all of them are returned. A k below 1
            yields an empty list.
        return_dist_sq : boolean
            Whether to return the squared Euclidean distances.

        Returns
        -------
        list<array-like>
            The nearest neighbors, nearest first.
            If `return_dist_sq` is true, the return will be:
                [(dist_sq, point), ...]
            else:
                [point, ...]
        """
        if k < 1:
            # Nothing to find; also avoids peeking into an empty heap.
            return []
        return self._get_knn(self._root, point, k, return_dist_sq, [])

    def get_nearest(self, point, return_dist_sq=True):
        """Returns the nearest neighbor.

        Parameters
        ----------
        point : array-like
            The point.
        return_dist_sq : boolean
            Whether to return the squared Euclidean distance.

        Returns
        -------
        array-like
            The nearest neighbor.
            If the tree is empty, returns `None`.
            If `return_dist_sq` is true, the return will be:
                (dist_sq, point)
            else:
                point
        """
        neighbors = self._get_knn(self._root, point, 1, return_dist_sq, [])
        return neighbors[0] if neighbors else None