-
Notifications
You must be signed in to change notification settings - Fork 65
/
demo.py
240 lines (195 loc) · 9.87 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import argparse
import json
import torch
from pipeline.tfidf_retriever import TfidfRetriever
from pipeline.graph_retriever import GraphRetriever
from pipeline.reader import Reader
from pipeline.sequential_sentence_selector import SequentialSentenceSelector
import logging
class DisableLogger:
    """Context manager that mutes the ``logging`` module inside a ``with`` block.

    On entry every record up to and including ``logging.CRITICAL`` is
    disabled; on exit the global override is reset to ``logging.NOTSET``.
    Note that exit always resets to ``NOTSET`` rather than restoring whatever
    override was active before entry.
    """

    def __enter__(self):
        """Disable all logging calls of severity CRITICAL and below."""
        logging.disable(logging.CRITICAL)

    def __exit__(self, a, b, c):
        """Lift the global logging override.

        The three positional arguments are the exception triple supplied by
        the ``with`` statement; they are ignored, and nothing is returned, so
        an exception raised inside the block keeps propagating.
        """
        logging.disable(logging.NOTSET)
class ODQA:
    """Question-answering pipeline used by the interactive demo.

    Chains four components, all built from the same parsed ``args``:
    a TF-IDF retriever, a graph retriever, a reader, and a sequential
    sentence selector for supporting facts.
    """

    def __init__(self, args):
        """Instantiate every pipeline component.

        Args:
            args: Parsed command-line namespace. This constructor reads
                ``no_cuda``, ``db_path`` and ``tfidf_path`` directly and
                forwards the whole namespace to the other components.
        """
        self.args = args

        # Use the GPU only when one is present and the user did not opt out.
        use_cuda = torch.cuda.is_available() and not self.args.no_cuda
        device = torch.device("cuda" if use_cuda else "cpu")

        # TF-IDF Retriever
        self.tfidf_retriever = TfidfRetriever(self.args.db_path, self.args.tfidf_path)

        # Graph Retriever
        self.graph_retriever = GraphRetriever(self.args, device)

        # Reader
        self.reader = Reader(self.args, device)

        # Supporting facts selector
        self.sequential_sentence_selector = SequentialSentenceSelector(self.args, device)

    def predict(self,
                questions: list):
        """Run the full pipeline over a batch of questions.

        Args:
            questions: Question strings. The question at position ``i`` is
                given the id ``'DEMO_<i>'``.

        Returns:
            A 4-tuple ``(tfidf_retrieval_output, graph_retrieval_output,
            reader_output, supporting_facts)``. ``reader_output`` is a list
            of dicts with the keys ``q_id``, ``question``, ``answer`` and
            ``context``. ``supporting_facts`` is an empty list when
            ``args.sequential_sentence_selector_path`` is ``None``.
        """
        print('-- Retrieving paragraphs by TF-IDF...', flush=True)
        tfidf_retrieval_output = []
        for index, question in enumerate(questions):
            q_id = 'DEMO_{}'.format(index)
            retrieved = self.tfidf_retriever.get_abstract_tfidf(q_id, question, self.args)
            tfidf_retrieval_output.extend(retrieved)

        print('-- Running the graph-based recurrent retriever model...', flush=True)
        graph_retrieval_output = self.graph_retriever.predict(
            tfidf_retrieval_output, self.tfidf_retriever, self.args)

        print('-- Running the reader model...', flush=True)
        answer, title = self.reader.predict(graph_retrieval_output, self.args)

        # Join the reader's per-question lookups back onto each sample.
        reader_output = []
        for sample in graph_retrieval_output:
            q_id = sample['q_id']
            reader_output.append({'q_id': q_id,
                                  'question': sample['question'],
                                  'answer': answer[q_id],
                                  'context': title[q_id]})

        # Supporting facts are only produced when a selector model was given.
        supporting_facts = []
        if self.args.sequential_sentence_selector_path is not None:
            print('-- Running the supporting facts retriever...', flush=True)
            supporting_facts = self.sequential_sentence_selector.predict(
                reader_output, self.tfidf_retriever, self.args)

        return tfidf_retrieval_output, graph_retrieval_output, reader_output, supporting_facts
def _build_parser():
    """Create the argument parser holding every option of the demo."""
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--graph_retriever_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Graph retriever model path.")
    parser.add_argument("--reader_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Reader model path.")
    parser.add_argument("--tfidf_path",
                        default=None,
                        type=str,
                        required=True,
                        help="TF-IDF path.")
    parser.add_argument("--db_path",
                        default=None,
                        type=str,
                        required=True,
                        help="DB path.")

    ## Other parameters
    parser.add_argument("--sequential_sentence_selector_path",
                        default=None,
                        type=str,
                        help="Supporting facts model path.")
    parser.add_argument("--max_sent_num",
                        default=30,
                        type=int)
    parser.add_argument("--max_sf_num",
                        default=15,
                        type=int)
    parser.add_argument("--bert_model_graph_retriever", default='bert-base-uncased', type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--bert_model_sequential_sentence_selector", default='bert-large-uncased', type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--max_seq_length",
                        default=378,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--max_seq_length_sequential_sentence_selector",
                        default=256,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")

    # RNN graph retriever-specific parameters
    parser.add_argument("--max_para_num",
                        default=10,
                        type=int)
    parser.add_argument('--eval_batch_size',
                        type=int,
                        default=5,
                        help="Eval batch size")
    parser.add_argument('--beam_graph_retriever',
                        type=int,
                        default=1,
                        help="Beam size for Graph Retriever")
    parser.add_argument('--beam_sequential_sentence_selector',
                        type=int,
                        default=1,
                        help="Beam size for Sequential Sentence Selector")
    parser.add_argument('--min_select_num',
                        type=int,
                        default=1,
                        help="Minimum number of selected paragraphs")
    parser.add_argument('--max_select_num',
                        type=int,
                        default=3,
                        help="Maximum number of selected paragraphs")
    parser.add_argument("--no_links",
                        action='store_true',
                        help="Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)")
    parser.add_argument("--pruning_by_links",
                        action='store_true',
                        help="Whether to do pruning by links (and top 1)")
    parser.add_argument("--expand_links",
                        action='store_true',
                        help="Whether to expand links with paragraphs in the same article (for NQ)")
    parser.add_argument('--tfidf_limit',
                        type=int,
                        default=None,
                        help="Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)")
    parser.add_argument("--split_chunk", default=100, type=int,
                        help="Chunk size for BERT encoding at inference time")
    parser.add_argument("--eval_chunk", default=500, type=int,
                        help="Chunk size for inference of graph_retriever")
    parser.add_argument("--tagme",
                        action='store_true',
                        help="Whether to use tagme at inference")
    parser.add_argument('--topk',
                        type=int,
                        default=2,
                        help="Whether to use how many paragraphs from the previous steps")
    parser.add_argument("--n_best_size", default=5, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    return parser


def main():
    """Run the interactive question-answering demo on the terminal.

    Parses the command line, builds the ``ODQA`` pipeline once, then loops
    over lines read from standard input. Each line may hold several
    questions separated by ``|||``. Entering ``q`` (or sending EOF / Ctrl-C
    at the prompt) ends the session; blank lines are ignored. For every
    batch the retrieval results, the reader results and, when present, the
    supporting facts are printed as indented JSON.
    """
    odqa = ODQA(_build_parser().parse_args())

    print()
    while True:
        try:
            line = input('Questions: ').strip()
        except (EOFError, KeyboardInterrupt):
            # Leave the prompt cleanly instead of dumping a traceback.
            print()
            break

        if line == 'q':
            break
        if not line:
            continue

        # Trim each question and drop empty segments such as a trailing "|||".
        questions = [part.strip() for part in line.split('|||')]
        questions = [part for part in questions if part]
        if not questions:
            continue

        tfidf_retrieval_output, graph_retriever_output, reader_output, supporting_facts = odqa.predict(questions)

        if graph_retriever_output is None:
            # The original formatted an undefined name here (NameError);
            # report the line the user actually typed instead.
            print()
            print('Invalid question! "{}"'.format(line))
            print()
            continue

        print()
        print('#### Retrieval results ####')
        print(json.dumps(graph_retriever_output, indent=4))
        print()

        print('#### Reader results ####')
        print(json.dumps(reader_output, indent=4))
        print()

        if len(supporting_facts) > 0:
            print('#### Supporting facts ####')
            print(json.dumps(supporting_facts, indent=4))
            print()
if __name__ == "__main__":
    # Run the demo with the logging module muted for the whole session, so
    # only the demo's own print() output reaches the terminal.
    with DisableLogger():
        main()