utils.py
import torch


def compute_CE(x):
    """
    Complexity estimate (CE) used by the CID similarity: the L2 norm of the
    successive differences of each row.
    x shape : (n, n_hidden)
    return  : (n,)
    """
    return torch.sqrt(torch.sum(torch.square(x[:, 1:] - x[:, :-1]), dim=1))
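# Illustration (not part of the original file): for a flat series every
# successive difference is zero, so its complexity estimate is 0, while an
# oscillating series of the same length gets a large CE, e.g.
#   compute_CE(torch.tensor([[1., 1., 1., 1.]]))  ->  tensor([0.])
#   compute_CE(torch.tensor([[0., 1., 0., 1.]]))  ->  tensor([1.7321])  # sqrt(3)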
def compute_similarity(z, centroids, similarity="EUC"):
    """
    Compute the distance between latent vectors z and the cluster centroids.
    similarity : one of ["EUC", "CID", "COR"] : EUC for Euclidean distance,
                 COR for a correlation-based distance, and CID for
                 Complexity Invariant Distance.
    z shape         : (batch_size, n_hidden)
    centroids shape : (n_clusters, n_hidden)
    output          : (batch_size, n_clusters)
    """
    n_clusters, n_hidden = centroids.shape[0], centroids.shape[1]
    bs = z.shape[0]
    if similarity == "CID":
        CE_z = compute_CE(z).unsqueeze(1)  # shape (batch_size, 1)
        CE_cen = compute_CE(centroids).unsqueeze(0)  # shape (1, n_clusters)
        z = z.unsqueeze(0).expand((n_clusters, bs, n_hidden))
        mse = torch.sqrt(torch.sum((z - centroids.unsqueeze(1)) ** 2, dim=2))
        CE_z = CE_z.expand((bs, n_clusters))  # (bs, n_clusters)
        CE_cen = CE_cen.expand((bs, n_clusters))  # (bs, n_clusters)
        CF = torch.max(CE_z, CE_cen) / torch.min(CE_z, CE_cen)
        return torch.transpose(mse, 0, 1) * CF
    elif similarity == "EUC":
        z = z.expand((n_clusters, bs, n_hidden))
        mse = torch.sqrt(torch.sum((z - centroids.unsqueeze(1)) ** 2, dim=2))
        return torch.transpose(mse, 0, 1)
    elif similarity == "COR":
        std_z = (
            torch.std(z, dim=1).unsqueeze(1).expand((bs, n_clusters))
        )  # (bs, n_clusters)
        mean_z = (
            torch.mean(z, dim=1).unsqueeze(1).expand((bs, n_clusters))
        )  # (bs, n_clusters)
        std_cen = (
            torch.std(centroids, dim=1).unsqueeze(0).expand((bs, n_clusters))
        )  # (bs, n_clusters)
        mean_cen = (
            torch.mean(centroids, dim=1).unsqueeze(0).expand((bs, n_clusters))
        )  # (bs, n_clusters)
        # covariance
        z_expand = z.unsqueeze(1).expand((bs, n_clusters, n_hidden))
        cen_expand = centroids.unsqueeze(0).expand((bs, n_clusters, n_hidden))
        prod_expec = torch.mean(z_expand * cen_expand, dim=2)  # (bs, n_clusters)
        pearson_corr = (prod_expec - mean_z * mean_cen) / (std_z * std_cen)
        return torch.sqrt(2 * (1 - pearson_corr))
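

# --- Usage sketch (illustrative, not from the original repository) ---
# A minimal example of how compute_similarity might be called to assign each
# latent vector to its nearest centroid. The shapes, the random tensors and
# the hard argmin assignment below are assumptions for demonstration only.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, n_hidden, n_clusters = 8, 16, 3
    z = torch.randn(batch_size, n_hidden)          # latent vectors
    centroids = torch.randn(n_clusters, n_hidden)  # cluster centroids

    for sim in ("EUC", "CID", "COR"):
        dist = compute_similarity(z, centroids, similarity=sim)  # (batch_size, n_clusters)
        assignments = torch.argmin(dist, dim=1)                  # index of the closest centroid
        print(sim, tuple(dist.shape), assignments.tolist())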