-
Notifications
You must be signed in to change notification settings - Fork 0
/
examples_toy_dataset.py
103 lines (80 loc) · 2.5 KB
/
examples_toy_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
from numpy import random as rnd
from sklearn.decomposition import PCA
from otpca import ot_pca_bcd, ot_pca_auto_diff
import pylab as plt
from utils import create_directory, save_figure
def main():
    """Compare classical PCA with the two OT-PCA solvers on a toy dataset.

    The data is a seeded ``100 x 20`` Gaussian sample (scaled by 0.1) in
    which rows 0-32 are shifted by +1 along column 0 and rows 33-65 by -1
    along column 1; the columns are then centred.  Both ``ot_pca_bcd`` and
    ``ot_pca_auto_diff`` are run with ``k=2`` and their outputs are plotted
    next to a 2-component scikit-learn PCA.

    Three figures are built: the 2-D embeddings (figure 1), the ``G``
    matrices returned by each solver (figure 2) and the logged loss /
    gradient-norm curves (figure 3).  ``save_figure`` is called after
    figures 1 and 3 only.
    """
    folder_path = create_directory('toy_dataset')

    # Seed the global NumPy generator so the toy data is reproducible.
    rnd.seed(123)

    # --- synthetic data -------------------------------------------------
    n_samples = 100
    n_features = 20
    noise = rnd.randn(n_samples, n_features)
    X = 0.1 * noise
    X[:33, 0] += 1    # first group: shifted along column 0
    X[33:66, 1] -= 1  # second group: shifted along column 1
    X = X - X.mean(axis=0)

    # --- solver settings ------------------------------------------------
    reg_value = 0.1
    sinkhorn_iters = 100
    mm_iters = 20
    step_size = 1e-1

    bcd_options = dict(
        k=2,
        method='MM',
        reg=reg_value,
        verbose=False,
        max_iter_sink=sinkhorn_iters,
        max_iter_MM=mm_iters,
    )
    Gbcd, Pbcd, log_bcd = ot_pca_bcd(X, **bcd_options)

    auto_options = dict(
        k=2,
        reg=reg_value,
        lr=step_size,
        max_iter=100,
        max_iter_sink=sinkhorn_iters,
        verbose=False,
        log=True,
        device='cpu',
    )
    Gauto, Pauto, log_auto = ot_pca_auto_diff(X, **auto_options)

    # --- 2-D embeddings -------------------------------------------------
    pca_points = PCA(n_components=2).fit(X).transform(X)

    def embed(G, P):
        # Plain projection X.P, and the same projection left-multiplied by
        # n * G (the variant the plot titles call "shrinked").
        projected = X.dot(P)
        return projected, n_samples * G.dot(projected)

    bcd_points, bcd_shrunk = embed(Gbcd, Pbcd)
    auto_points, auto_shrunk = embed(Gauto, Pauto)

    # Figure 1: PCA vs. OT-PCA embeddings.  Slot 4 of the 2x3 grid is
    # intentionally left empty, as in the original layout.
    plt.figure(1, (15, 7))
    scatter_panels = (
        (1, pca_points, 'PCA'),
        (2, bcd_points, 'Sinkhorn PCA (BCD)'),
        (3, bcd_shrunk, 'Sinkhorn PCA shrinked (BCD)'),
        (5, auto_points, 'OT PCA (Auto-diff)'),
        (6, auto_shrunk, 'OT PCA shrinked (Auto-diff)'),
    )
    for slot, points, title in scatter_panels:
        plt.subplot(2, 3, slot)
        plt.scatter(points[:, 0], points[:, 1])
        plt.title(title)
    save_figure(folder_path, 'plot_data_2d')

    # Figure 2: the G matrices returned by the two solvers.
    # NOTE(review): no save_figure call follows this figure, so it is
    # presumably never written to disk -- confirm whether that is intended.
    plt.figure(2)
    image_panels = (
        (1, Gbcd, 'G (BCD)'),
        (2, Gauto, 'G (Auto-diff)'),
    )
    for slot, matrix, title in image_panels:
        plt.subplot(1, 2, slot)
        plt.imshow(matrix)
        plt.title(title)

    # Figure 3: logged loss curves and the auto-diff gradient norm.
    plt.figure(3, (15, 7))
    curve_panels = (
        (plt.plot, log_bcd, 'loss', 'Loss BCD'),
        (plt.plot, log_auto, 'loss', 'Loss Riemannian'),
        (plt.loglog, log_auto, 'gradnorm', 'Riemannian gradnorm'),
    )
    for slot, (draw, log, key, title) in enumerate(curve_panels, start=1):
        ax = plt.subplot(1, 3, slot)
        # Show absolute tick values rather than an offset notation.
        ax.ticklabel_format(useOffset=False)
        iterations = np.arange(1, len(log[key]) + 1)
        draw(iterations, log[key])
        plt.title(title)
    save_figure(folder_path, 'plot_loss')
# Run the toy-dataset example only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()