-
Notifications
You must be signed in to change notification settings - Fork 2
/
kmeans.py
159 lines (144 loc) · 4.97 KB
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python
import getopt, sys
from neo4j import GraphDatabase
from sklearn.cluster import KMeans
import numpy as np
# Global defaults, some based on our demo and some on the algo defaults.
DEFAULT_URI = "bolt://localhost:7687"
DEFAULT_USER = "neo4j"
DEFAULT_PASS = "password"
DEFAULT_REL = "UNWEIGHTED_APPEARED_WITH"
DEFAULT_LABEL = "Character"
DEFAULT_PROP = "clusterId"
DEFAULT_P = 1.0
DEFAULT_Q = 1.0
DEFAULT_D = 16
DEFAULT_WALK = 80
DEFAULT_K=6
NODE2VEC_CYPHER = """
CALL gds.alpha.node2vec.stream({
nodeProjection: $L,
relationshipProjection: {
EDGE: {
type: $R,
orientation: 'UNDIRECTED'
}
},
embeddingSize: $d,
returnFactor: $p,
inOutFactor: $q,
walkLength: $l
}) YIELD nodeId, embedding
"""
UPDATE_CYPHER = """
UNWIND $updates AS updateMap
MATCH (n) WHERE id(n) = updateMap.nodeId
SET n += updateMap.valueMap
"""
def usage():
print("usage:\t kmeans.py [-A BOLT URI default: {}] [-U USERNAME (default: {})] [-P PASSWORD (default: {})]"
.format(DEFAULT_URI, DEFAULT_USER, DEFAULT_PASS))
print("supported parameters:")
print("\t-R RELATIONSHIP_TYPE (default: '{}'".format(DEFAULT_REL))
print("\t-L NODE_LABEL (default: '{}'".format(DEFAULT_LABEL))
print("\t-C CLUSTER_PROPERTY (default: '{}'".format(DEFAULT_PROP))
print("\t-d DIMENSIONS (default: {})".format(DEFAULT_D))
print("\t-p RETURN PARAMETER (default: {})".format(DEFAULT_P))
print("\t-q IN-OUT PARAMETER (default: {})".format(DEFAULT_Q))
print("\t-k K-MEANS NUM_CLUSTERS (default: {})".format(DEFAULT_K))
print("\t-l WALK_LENGTH (default: {})".format(DEFAULT_WALK))
sys.exit(1)
def extract_embeddings(driver, label=DEFAULT_LABEL, relType=DEFAULT_REL,
p=DEFAULT_P, q=DEFAULT_Q, d=DEFAULT_D, l=DEFAULT_WALK):
"""
Call the GDS neo2vec routine using the given driver and provided params.
"""
print("Generating graph embeddings (p={}, q={}, d={}, l={}, label={}, relType={})"
.format(p, q, d, l, label, relType))
embeddings = []
with driver.session() as session:
results = session.run(NODE2VEC_CYPHER, L=label, R=relType,
p=float(p), q=float(q), d=int(d), l=int(l))
for result in results:
embeddings.append(result)
print("...generated {} embeddings".format(len(embeddings)))
return embeddings
def kmeans(embeddings, k=DEFAULT_K, clusterProp=DEFAULT_PROP):
"""
Given a list of dicts like {"nodeId" 1, "embedding": [1.0, 0.1, ...]},
generate a list of dicts like {"nodeId": 1, "valueMap": {"clusterId": 2}}
"""
print("Performing K-Means clustering (k={}, clusterProp='{}')"
.format(k, clusterProp))
X = np.array([e["embedding"] for e in embeddings])
kmeans = KMeans(n_clusters=int(k)).fit(X)
results = []
for idx, cluster in enumerate(kmeans.predict(X)):
results.append({ "nodeId": embeddings[idx]["nodeId"],
"valueMap": { clusterProp: int(cluster) }})
print("...clustering completed.")
return results
def _update_tx(tx, cypher, **kwargs):
result = tx.run(cypher, kwargs)
return result.consume()
def update_clusters(driver, clusterResults):
"""
Given a list of dicts with "nodeId" string and a "valueMap" dict, update
the graph by setting the properties from the "valueMap" on each node.
"""
print("Updating graph...")
with driver.session() as session:
result = session.write_transaction(_update_tx, UPDATE_CYPHER, updates=clusterResults)
print("...update complete: {}".format(result.counters))
if __name__ == '__main__':
# getopt, because: "POSIX getopt(1) is The Correct Way" ~sircmpwn
try:
opts, args = getopt.getopt(sys.argv[1:], "hA:U:P:C:R:l:L:p:q:d:k:")
except getopt.GetoptError as err:
print(err)
usage()
uri = DEFAULT_URI
user = DEFAULT_USER
password = DEFAULT_PASS
relType = DEFAULT_REL
label = DEFAULT_LABEL
clusterProp = DEFAULT_PROP
p = DEFAULT_P
q = DEFAULT_Q
d = DEFAULT_D
k = DEFAULT_K
l = DEFAULT_WALK
for o, a in opts:
if o == "-h":
usage()
elif o == "-A":
uri = a
elif o == "-U":
user = a
elif o == "-P":
password = a
elif o == "-R":
relType = a
elif o == "-L":
label = a
elif o == "-C":
clusterProp = a
elif o == "-p":
p = a
elif o == "-q":
q = a
elif o == "-d":
d = a
elif o == "-k":
k = a
elif o == "-l":
l = a
else:
usage()
print("Connecting to uri: {}".format(uri))
driver = GraphDatabase.driver(uri, auth=(user, password))
embeddings = extract_embeddings(driver, label=label, relType=relType,
p=p, q=q, d=d, l=l)
clusters = kmeans(embeddings, k=k, clusterProp=clusterProp)
update_clusters(driver, clusters)
driver.close()