service.py
import os
import argparse
import json
import re
import time
import threading
import uuid
from datetime import datetime

import requests
from flask import Flask, abort, request
from flask_cors import CORS

from modules.sent2vec_reranker import SentenceRanker
## Create the Flask application and enable cross-origin requests
app = Flask(__name__)
CORS(app)
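## Fetch full paper records for a list of paper ids from the external paper database service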
def get_papers( paper_list, projection = None ):
global args
results = requests.post(
args.paper_database_service_address,
data = json.dumps( {
"paper_list" : paper_list,
"projection" : projection
} ),
headers = {"Content-Type": "application/json", 'Connection': 'close'} ).json()["response"]
return results
"""
Expected fields in the /document-search request payload:
'ranking_id_value': '', 'ranking_id_field': '', 'ranking_id_type': '',
'ranking_collection': '', 'keywords': '', 'ranking_variable': '',
'highlight_source': '', 'viewing_id': '', 'username': ''
"""
def parse_request_for_document_search( request_info ):
"""
Get the ranking source:
    1) If the ranking variable is provided, use it directly as the ranking source.
    2) Otherwise, check whether query paper id information is provided. If so, fetch the paper content and use it as the ranking source; if not, set the ranking source to the empty string "".
"""
ranking_variable = request_info.get("ranking_variable", "").strip()
if ranking_variable != "":
ranking_source = ranking_variable
else:
try:
paper_record = get_papers( [ {
"collection":request_info.get("ranking_collection", "").strip(),
"id_field":str(request_info.get("ranking_id_field", "")).strip(),
"id_type":"int",
"id_value":int( str(request_info.get("ranking_id_value", "")).strip() )
} ], {"Title":1, "Content.Abstract_Parsed":1, "Content.Fullbody_Parsed":1} )[0]
assert paper_record is not None and paper_record["Title"] != "Not Available"
ranking_source = " ".join( [ paper_record["Title"] ] + get_sentence_list_from_parsed( paper_record["Content"]["Abstract_Parsed"] + paper_record["Content"]["Fullbody_Parsed"] ))
        except Exception:
            ## Fall back to an empty ranking source if the paper lookup fails
            ranking_source = ""
"""
1) Get the keywords;
2) Normalize the keywords into the standard format.
Note: This normalization code will be changed if the frontend organizes the keywords using a different syntax.
"""
ngram_connector_matcher = re.compile("_(?![pP]arsed)")
keywords = request_info.get("keywords", "").strip()
keywords_list = []
for w in keywords.split("\\t"):
# w = " ".join( w.replace("_"," ").split())
w = " ".join( ngram_connector_matcher.sub(" ", w).split() )
w = w.replace("|","<OR>")
w = w.replace("!","<NOT>")
if w.strip() != "":
keywords_list.append(w.strip())
keywords = "<AND>".join( keywords_list )
"""
Define the default behavior when either ranking source, keywords, or both are missing.
"""
if ranking_source == "" and keywords == "":
print("Warning: Neither ranking source nor keywords are provided!")
elif ranking_source == "" and keywords != "":
ranking_source = keywords.replace( "<OR>", " " ).replace( "<NOT>", " " ).replace( "<AND>", " " )
print("Warning: Only keywords are provided! Using keywords also as ranking source!")
return ranking_source, keywords
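## Worker executed in a separate thread: queries one prefetching server and stores its result
## (or an empty fallback on failure) in the shared results dict under its thread index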
def prefetch_kernel(ranking_source, keywords, paper_list, nResults, service_address, thread_i, results, timeout):
try:
res = requests.post( service_address ,
data = json.dumps( {
"ranking_source" : ranking_source,
"keywords":keywords,
"paper_list":paper_list,
"nResults":nResults
} ),
headers = {"Content-Type": "application/json", 'Connection': 'close'},
timeout = timeout
).json()
assert isinstance(res["response"] , list)
res["nMatchingDocuments"] = int(res.get("nMatchingDocuments", 0))
    except Exception:
        ## On timeout or a malformed response, fall back to an empty result
res = {
"response":[],
"nMatchingDocuments":0
}
results[thread_i] = res
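## Fan the query out to all prefetching servers in parallel and merge their candidate
## paper ids and matching-document counts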
def prefetch( ranking_source, keywords, paper_list, nResults, service_address_list, timeout ):
results = {}
threads = []
for thread_i, service_address in enumerate( service_address_list ):
t = threading.Thread( target = prefetch_kernel, args = ( ranking_source, keywords, paper_list, nResults, service_address, thread_i, results, timeout ) )
t.start()
threads.append(t)
for t in threads:
t.join()
prefetched_paper_id_list = []
nMatchingDocuments = 0
for thread_i in results:
prefetched_paper_id_list += results[thread_i]["response"]
nMatchingDocuments += results[thread_i]["nMatchingDocuments"]
return prefetched_paper_id_list, nMatchingDocuments
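## Remove duplicate papers: each paper is represented by its lowercased title plus the first
## author's family name, all entries are ranked against a dummy query, and an entry whose
## similarity score equals that of the previously kept entry is treated as a duplicate and skipped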
def remove_duplicate( paper_id_list ):
global sentence_ranker
tic = time.time()
dpc_papers_info = [
" ".join([item.get( "Title","" ).lower()] + [ author["FamilyName"].lower() for author in item.get("Author",[])[:1] ] ) ## only use the first author
for item in get_papers( paper_id_list, { "Title":1, "Author":1 } )]
print("loading paper time:", time.time() - tic)
if len(dpc_papers_info)!=len(paper_id_list):
return paper_id_list
sims, doc_indices = sentence_ranker.rank_sentences( "dummy query for duplicate checking", dpc_papers_info )
doc_indices_wo_duplicates = []
sims_wo_duplicates = []
for pos in range(len(sims)):
if len(sims_wo_duplicates) == 0 or sims[pos] < sims_wo_duplicates[-1]:
doc_indices_wo_duplicates.append( doc_indices[pos] )
sims_wo_duplicates.append( sims[pos] )
return [paper_id_list[idx] for idx in sorted(doc_indices_wo_duplicates) ]
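## Flatten a parsed section list into a flat list of section titles and sentence strings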
def get_sentence_list_from_parsed( parsed ):
sentence_list = []
for section in parsed:
sentence_list.append(str(section.get( "section_title", "" )))
for para in section.get("section_text",[]):
for sen in para.get("paragraph_text", []):
sentence_list.append( str(sen.get("sentence_text","")) )
return sentence_list
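## Concatenate a paper's title, abstract sentences and full-body sentences into a single text string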
def parse_document( doc_data ):
ngram_set = set()
## Title
title = str(doc_data.get("Title", "")).strip()
## Abstract
abstract_sen_list = get_sentence_list_from_parsed(doc_data.get( "Content", {} ).get( "Abstract_Parsed", [] ))
## Fullbody
fullbody_sen_list = get_sentence_list_from_parsed(doc_data.get( "Content", {} ).get( "Fullbody_Parsed", [] ))
sen_list = [ title ] + abstract_sen_list + fullbody_sen_list
## no need to tokenize here, since it is done internally within sentence ranker
doc_text = " ".join( sen_list )
return doc_text
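## Re-rank the prefetched candidates by embedding similarity between the ranking source and
## each paper's full text, optionally truncating to the top nResults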
def rank_based_on_query_to_doc_similarity( paper_id_list, ranking_source, nResults = None ):
global sentence_ranker
if ranking_source.strip() == "":
return paper_id_list
tic = time.time()
paper_content_list = get_papers( paper_id_list )
if len(paper_content_list) != len(paper_id_list):
return paper_id_list
print( "load paper time:", time.time() - tic )
try:
doc_text_list = [ parse_document( paper_content ) for paper_content in paper_content_list ]
_, doc_indices = sentence_ranker.rank_sentences( ranking_source, doc_text_list )
selected_papers_to_be_reranked = [ paper_id_list[idx] for idx in doc_indices ]
if nResults is not None:
selected_papers_to_be_reranked = selected_papers_to_be_reranked[:nResults]
    except Exception:
        ## If embedding-based ranking fails, keep the original order
        selected_papers_to_be_reranked = paper_id_list
return selected_papers_to_be_reranked
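## Delegate final reranking of the candidate list to the external document reranking service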
def rerank( paper_list, ranking_source, keywords, nResults, reranking_method ):
global args
res = requests.post( args.document_reranking_service_address ,
data = json.dumps( {
"paper_list" : paper_list,
"ranking_source" : ranking_source,
"keywords" : keywords,
"nResults" : nResults,
"reranking_method": reranking_method
} ),
headers = {"Content-Type": "application/json", 'Connection': 'close'} ).json()["response"]
return res
@app.route('/document-search', methods=['POST'])
def document_search():
"""Document search API route"""
global args, sem
sem.acquire()
try:
        ## A missing or empty JSON body fails this assertion and triggers the 400 response below
        assert request.json
        request_info = request.json
## Start querying
query_start_time = time.time()
ranking_source, keywords = parse_request_for_document_search( request_info )
if "paper_list" not in request_info or not isinstance(request_info["paper_list"] , list):
paper_list = None
else:
paper_list = request_info["paper_list"]
nResults = request_info.get( "nResults", 100 )
prefetch_nResults_per_collection = request_info.get( "prefetch_nResults_per_collection", nResults )
timeout = request_info.get( "timeout", 10 )
requires_removing_duplicates = request_info.get( "requires_removing_duplicates", True )
## rank again based on embedding-based NN search, to get globally closest nResults prefetched candidates
requires_additional_prefetching = request_info.get( "requires_additional_prefetching", True )
requires_reranking = request_info.get( "requires_reranking", True )
reranking_method = request_info.get( "reranking_method", "scibert" )
        ## NOTE: the request-supplied options above are currently force-overridden here
        requires_removing_duplicates = True
requires_additional_prefetching = True
requires_reranking = True
reranking_method = "scibert"
## prefetch results from a list of prefetching document search servers
prefetched_paper_id_list, nMatchingDocuments = prefetch(
ranking_source + " " + keywords.replace( "<OR>", " " ).replace( "<NOT>", " " ).replace( "<AND>", " " ),
keywords, paper_list, prefetch_nResults_per_collection,
args.prefetch_service_address_list,
timeout
)
if requires_removing_duplicates:
## remove duplicate
prefetched_paper_id_list = remove_duplicate( prefetched_paper_id_list )
if requires_additional_prefetching:
## rank again based on embedding-based NN search, to get globally closest nResults prefetched candidates
prefetched_paper_id_list = rank_based_on_query_to_doc_similarity( prefetched_paper_id_list, ranking_source, nResults )
if requires_reranking:
## reranking the results gathered from different servers
selected_papers = rerank( prefetched_paper_id_list, ranking_source, keywords, nResults, reranking_method )
else:
selected_papers = prefetched_paper_id_list
stats={
"DurationTotalSearch":int((time.time() - query_start_time) * 1000),
"nMatchingDocuments": nMatchingDocuments
}
json_out = { "query_id": str( uuid.uuid4() ), "response" : selected_papers, "search_stats":stats}
print("Doc search success.")
    except Exception:
sem.release()
abort(400)
sem.release()
return json.dumps(json_out), 201
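## Illustrative call to this endpoint (hostname, default port, and payload values are assumptions,
## not taken from the repository):
##   curl -X POST http://localhost:8060/document-search \
##        -H "Content-Type: application/json" \
##        -d '{"keywords": "graph_neural_networks", "nResults": 20}'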
@app.route('/click_feedback', methods=['POST'])
def click_feedback():
startTime = datetime.now()
if not request.json:
print("no request.json")
abort(400)
if "query_id" not in request.json:
print("no query_id provided!")
abort(400)
else:
query_id = request.json['query_id']
print("query_id: " + query_id)
if "paper_id" not in request.json:
print("no paper_id provided!")
abort(400)
else:
paper_id = request.json['paper_id']
print("paper_id: " + json.dumps(paper_id))
return json.dumps({"response":"Feedback Received!"}), 201
PAPER_DATABASE_SERVICE_ADDRESS = os.getenv("PAPER_DATABASE_SERVICE_ADDRESS")
PREFETCH_SERVICE_ADDRESSES = [ addr.strip() for addr in os.getenv("PREFETCH_SERVICE_ADDRESSES", "").split(",") if addr.strip() != "" ]
RERANK_SERVICE_ADDRESS = os.getenv("RERANK_SERVICE_ADDRESS")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument( "-flask_port", type = int, default = 8060 )
parser.add_argument( "-paper_database_service_address", default = PAPER_DATABASE_SERVICE_ADDRESS )
parser.add_argument( "-prefetch_service_address_list", nargs = "+", default = PREFETCH_SERVICE_ADDRESSES )
parser.add_argument( "-document_reranking_service_address", default = RERANK_SERVICE_ADDRESS )
parser.add_argument( "-embedding_based_ranking_model_path", default = "/app/models/sent2vec/model_256.bin" )
args = parser.parse_args()
    ## Load the sent2vec sentence ranker used for duplicate removal and embedding-based ranking
    sentence_ranker = SentenceRanker( args.embedding_based_ranking_model_path )
    ## The default semaphore value of 1 serializes /document-search requests
    sem = threading.Semaphore()
    print("\n\nWaiting for requests...")
    app.run(host='0.0.0.0', port=args.flask_port, threaded=True, debug=True)