Dilated layer takes more than k neighbours #96

The Dilated layer doesn't take into account k. This can lead to taking more neighbours than intended.

Comments
I'm not even sure the current behaviour is what's intended. Taking:

import torch
# Dilated is the layer from this repo that the issue is about (import omitted here).

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
res = Dilated(k=2, dilation=2)(t)
print(res)
# tensor([[0, 0, 0, 1, 1, 2, 2, 2],
#         [0, 2, 4, 1, 3, 0, 2, 4]])

For node 1 we end up with 2 neighbours (1 and 3), but nodes 0 and 2 each get 3 neighbours (0, 2, 4) instead of k = 2.
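
That output is consistent with the layer simply striding over the whole edge list, ignoring both k and the per-node boundaries. A minimal sketch of that behaviour (a guess based on the printed output, not the repo's actual code):

import torch

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)

dilation = 2
# Assumed behaviour: keep every `dilation`-th edge globally, with no notion of k
# and no grouping by source node.
res = t[:, ::dilation]
print(res)
# tensor([[0, 0, 0, 1, 1, 2, 2, 2],
#         [0, 2, 4, 1, 3, 0, 2, 4]])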
We should first build a knn graph that has k * dilation neighbours per node.
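
As a rough illustration of that idea (the helper knn_edge_index and the coordinates x below are hypothetical, not part of this library):

import torch

def knn_edge_index(x, num_neighbours):
    # Hypothetical helper: build a (2, n * num_neighbours) edge_index where every
    # node has exactly `num_neighbours` nearest neighbours (self-loops excluded).
    dist = torch.cdist(x, x)                                 # (n, n) pairwise distances
    dist.fill_diagonal_(float("inf"))                        # exclude self as a neighbour
    nbrs = dist.topk(num_neighbours, largest=False).indices  # (n, num_neighbours)
    src = torch.arange(x.size(0)).repeat_interleave(num_neighbours)
    return torch.stack([src, nbrs.reshape(-1)])

k, dilation = 2, 2
x = torch.randn(6, 4)                         # 6 hypothetical nodes with 4 features
edge_index = knn_edge_index(x, k * dilation)  # 4 candidate neighbours per node
print(edge_index.shape)                       # torch.Size([2, 24])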
A possible solution would be (using einops):

import torch
from einops import rearrange

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
k = 2
d = 2

u, counts = torch.unique(t[0], return_counts=True)
# Assume every node has the same number of constructed neighbours;
# this could also be passed in as a parameter.
k_constructed = int(counts[0])

# Group the flat edge list per source node.
res1 = rearrange(t, "e (n2 k_constructed) -> e n2 k_constructed", k_constructed=k_constructed)
print(res1)
# tensor([[[0, 0, 0, 0, 0],
#          [1, 1, 1, 1, 1],
#          [2, 2, 2, 2, 2]],
#         [[0, 1, 2, 3, 4],
#          [0, 1, 2, 3, 4],
#          [0, 1, 2, 3, 4]]])

res2 = res1[:, :, ::d]  # dilate: keep every d-th neighbour of each node
print(res2)
# tensor([[[0, 0, 0],
#          [1, 1, 1],
#          [2, 2, 2]],
#         [[0, 2, 4],
#          [0, 2, 4],
#          [0, 2, 4]]])

res3 = res2[:, :, :k]  # take the first k neighbours
print(res3)
# tensor([[[0, 0],
#          [1, 1],
#          [2, 2]],
#         [[0, 2],
#          [0, 2],
#          [0, 2]]])

res4 = rearrange(res3, "e d1 d2 -> e (d1 d2)")  # flatten back to an edge list
print(res4)
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 2, 0, 2, 0, 2]])
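
For reference, the same dilate-then-truncate steps can be written with plain reshape/slicing; dilate_edge_index below is a hypothetical helper, not an API of this repo:

import torch

def dilate_edge_index(edge_index, k, dilation, neighbours_per_node):
    # edge_index: (2, num_nodes * neighbours_per_node), edges grouped by source node.
    # Keep every `dilation`-th neighbour of each node, then truncate to k.
    grouped = edge_index.reshape(2, -1, neighbours_per_node)  # (2, num_nodes, neighbours_per_node)
    dilated = grouped[:, :, ::dilation][:, :, :k]             # (2, num_nodes, k)
    return dilated.reshape(2, -1)

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
print(dilate_edge_index(t, k=2, dilation=2, neighbours_per_node=5))
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 2, 0, 2, 0, 2]])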
Thanks for the suggestion @zademn. That is definitely a good idea if we are dealing with a more complex case. But in our example, we always build knn graphs with k * dilation neighbours per node.
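
To illustrate why that assumption makes the plain strided slice sufficient, a sketch with k = 2 and dilation = 2 (so exactly 4 neighbours per node in the constructed graph):

import torch

k, dilation = 2, 2

# A knn graph where every node has exactly k * dilation = 4 neighbours.
t = torch.tensor(
    [
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
    ]
)

# The global stride-`dilation` slice now yields exactly k neighbours per node,
# because each node's block of k * dilation edges is evenly divided by the dilation.
res = t[:, ::dilation]
print(res)
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 2, 0, 2, 0, 2]])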