-
Notifications
You must be signed in to change notification settings - Fork 13
/
d2clusters.m
147 lines (116 loc) · 3.27 KB
/
d2clusters.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
function [clusters, labels] = d2clusters( db, k)
%D2CLUSTERS  k-centroid (D2) clustering of multiphase discrete distributions.
%
% INPUT
%   db : cell array with one entry per phase. Each db{p} holds
%          .stride  number of support points of every sample
%          .supp    support points, one column per point; the points of
%                   sample idx occupy a contiguous run of .stride(idx) columns
%          .w       probabilities of the support points (same ordering)
%        Every phase must describe the same n = length(db{1}.stride) samples.
%   k  : number of clusters (default 5)
% OUTPUT
%   clusters : k-by-1 cell; clusters{j}{p} (fields .supp, .w) is the
%              centroid of cluster j in phase p
%   labels   : 1-by-n cluster index of every sample
%
% Side effects: if 'clusters_tmp.mat' exists in the current folder the run
% resumes from it, and the file is rewritten after every relabeling round.
% The global IDX receives sort(DC,2) (per-cluster ranking of the samples),
% and progress is printed to the file id held in the global stdoutput.
global stdoutput IDX;
if nargin == 1
    k = 5;
end
n = length(db{1}.stride);   % number of samples
labels = randi(k,1,n);
nROUND = 5;                 % maximum number of relabel/update rounds
nphase = length(db);
clusters = cell(k,1);
isload = false;
if exist('clusters_tmp.mat','file')
    % restore the state (clusters, IDX, labels) saved by a previous run
    load clusters_tmp.mat;
    isload = true;
else
    % initialize every centroid with a randomly chosen sample; draw distinct
    % samples when possible so that no two clusters start out identical
    if n >= k
        centroid_init = randperm(n, k);
    else
        centroid_init = randi(n, [k,1]);
    end
    for j=1:k
        clusters{j} = cell(nphase,1);
    end
    for p=1:nphase
        % column range first(s):last(s) holds the supports of sample s
        last = cumsum(db{p}.stride);
        first = last - db{p}.stride + 1;
        for j=1:k
            s = centroid_init(j);
            cols = first(s):last(s);
            clusters{j}{p}.supp = db{p}.supp(:,cols);
            clusters{j}{p}.w = db{p}.w(cols);
        end
    end
end
coeff = ones(nphase,1);     % per-phase weights of the combined distance
% main algorithm of k centroid clustering
for i=1:nROUND+1
    fprintf(stdoutput, 'Round %d ... ', i);
    % a resumed run already has labels for its centroids: skip relabeling once
    if ~(i==1 && isload)
        % distance of every sample to every centroid, one slice per phase
        % (kantorovich is defined elsewhere; presumably a transport distance)
        D = zeros(k,n,nphase);
        for p=1:nphase
            last = cumsum(db{p}.stride);
            first = last - db{p}.stride + 1;
            for j=1:k
                for idx=1:n
                    cols = first(idx):last(idx);
                    D(j,idx,p) = kantorovich(clusters{j}{p}.supp, clusters{j}{p}.w, ...
                        db{p}.supp(:,cols), db{p}.w(cols));
                end
            end
        end
        DC = zeros(k,n);
        for p=1:nphase
            DC = DC + coeff(p)*D(:,:,p);
        end
        labelspast = labels;
        % reduce along dim 1 explicitly so that k == 1 still yields 1-by-n labels
        [~, labels] = min(DC, [], 1);
        nchanged = sum(labelspast ~= labels);
        fprintf(stdoutput, '%d labels change \n', nchanged);
        if (i==nROUND+1) || (nchanged == 0)
            break;
        end
        % export rank to each cluster centroid
        [~, IDX] = sort(DC,2);
        % save result in each round
        save clusters_tmp.mat clusters IDX labels
    end
    % recompute the centroids from the current labels
    % NOTE(review): a cluster with no members is still passed to centroid();
    % confirm centroid_sphADMM copes with an empty selection.
    for j=1:k
        fprintf(stdoutput, '\n\t cluster %d - ', j);
        clusters{j} = centroid(j, labels, db, clusters{j});
    end
end
end
function [c] = centroid( lb, labels, db, c0)
%CENTROID  Multiphase centroid of the samples whose label equals lb.
%
% INPUT
%   lb     : label of the cluster whose members are combined
%   labels : per-sample labels, one entry per element of db{p}.stride
%   db     : cell array of phases, each with fields .stride, .supp, .w
%   c0     : optional warm-start centroids, one cell per phase; when it is
%            omitted an empty cell is supplied for every phase
% OUTPUT
%   c      : nphase-by-1 cell, c{p} being the centroid for phase p
%
% Side effects: prints progress to the file id in the global stdoutput and
% stores the centroid_sphADMM wall time of the latest phase in ctime(1).
global stdoutput ctime bufferc;
nphase = length(db);
c = cell(nphase,1);
if nargin < 4
    c0 = cell(nphase,1);
end
for ph = 1:nphase
    fprintf(stdoutput, '\n\t\tphase %d: ', ph);
    % expand the per-sample labels to per-support-point labels, then keep
    % only the support points (and strides) of members of cluster lb
    member = (getwarm(labels, db{ph}.stride) == lb);
    phaseW = db{ph}.w(member);
    phaseSupp = db{ph}.supp(:, member);
    phaseStride = db{ph}.stride(labels == lb);
    t0 = tic;
    c{ph} = centroid_sphADMM(phaseStride, phaseSupp, phaseW, c0{ph});
    ctime(1) = toc(t0);
    % alternative solver kept for reference:
    %bufferc{1} = c{ph};
    %ctimer = tic;c{ph} = centroid_sphLP(stride{ph}, supp{ph}, w{ph});ctime(2)=toc(ctimer);
    %bufferc{2} = c{ph};
end
end
function [warmlabels] = getwarm(lbs, stride)
%GETWARM  Repeat each sample label once per support point of that sample.
%
% INPUT
%   lbs    : per-sample labels, e.g. [1,2,1,3,2]
%   stride : per-sample support sizes, e.g. [4,3,2,3,4]
% OUTPUT
%   warmlabels : 1-by-sum(stride) row, e.g. [1 1 1 1 2 2 2 1 1 3 3 3 2 2 2 2]
%
stopAt = cumsum(stride);           % last position covered by sample s
startAt = stopAt - stride + 1;     % first position covered by sample s
warmlabels = zeros(1, sum(stride));
for s = 1:length(stride)
    warmlabels(startAt(s):stopAt(s)) = lbs(s);
end
end