-
Notifications
You must be signed in to change notification settings - Fork 2
/
count_test.py
76 lines (65 loc) · 2.22 KB
/
count_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/python
#-*- coding:utf-8 -*-
############################
#File Name: count_test.py
#Author: chi xiao
#Mail:
#Created Time:
############################
from os.path import join
import numpy as np
import keras.backend as K
import tensorflow as tf
from keras.models import Sequential
import json
import pickle
import os
from keras.layers import Dense, Dropout, LSTM, Bidirectional
from settings import cuda_visible_devices, pubs_validate_path, weighted_embedding_path
from count_train import count_model_parameters_path, root_mean_squared_error, root_mean_log_squared_error, paper_feature
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices
pubs_validate_cluster_path = "./output/validate_cluster_num.json"
def create_model():
model = Sequential()
model.add(Bidirectional(LSTM(64), input_shape=(300, 100)))
model.add(Dropout(0.5))
model.add(Dense(1))
model.compile(loss="msle",
optimizer='rmsprop',
metrics=[root_mean_squared_error, root_mean_log_squared_error])
model.load_weights(count_model_parameters_path)
return model
def test_validate(model,k=300,flatten=False):
print ("predict cluster number ...")
with open(pubs_validate_path,'r') as f:
pubs_validate_dict = json.load(f)
author_paper_dict = {}
for author,papers in pubs_validate_dict.items():
author_paper_dict.setdefault(author,[])
for paper in papers:
author_paper_dict[author].append(paper['id'])
xs = []
names = []
for name in author_paper_dict.keys():
names.append(name)
x = []
items = author_paper_dict[name]
#print(items)
sampled_points = [items[p] for p in np.random.choice(len(items), k, replace=True)]
for p in sampled_points:
x.append(paper_feature[p])
if flatten:
xs.append(np.sum(x, axis=0))
else:
xs.append(np.stack(x))
xs = np.stack(xs)
kk = model.predict(xs)
return names, kk
if __name__=="__main__":
model = create_model()
names,kk = test_validate(model)
kk = map(int,np.squeeze(kk))
result = dict(zip(names,kk))
print (result)
with open(pubs_validate_cluster_path,'w') as f:
json.dump(result,f)