-
Notifications
You must be signed in to change notification settings - Fork 0
/
chat.py
178 lines (154 loc) · 7.75 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from email import message
import os
import argparse
from plato.args import str2bool
from plato.args import parse_args
from plato.data.dataset import Dataset
from plato.data.field import BPETextField
from plato.models.generator import Generator
import json
import numpy as np
import paddle.fluid as fluid
import time
from plato.args import str2bool
from plato.data.data_loader import DataLoader
from plato.data.dataset import LazyDataset
from plato.trainer import Trainer
from plato.models.model_base import ModelBase
import plato.modules.parallel as paralle
from fastapi import FastAPI,Query
from enum import Enum
from typing import Optional
from pydantic import BaseModel
# Sample input used by the import-time setup further down, which splits it on
# tab characters into (knowledge, src, judge).
# NOTE(review): no tab characters are apparent in these literals as shown here;
# confirm the real file contains exactly two, otherwise that unpack raises
# ValueError when the module is imported.
# This first sample is overwritten by the assignment on the next line and is
# never read.
text = "START Guo Shaoyun Blessed love __eou__ Blessed love to star Guo Shaoyun __eou__ Guo Shaoyun comment Tvber in my eyes __eou__ Guo Shaoyun date of birth 1970-8-25 __eou__ Guo Shaoyun height 168cm __eou__ Guo Shaoyun Gender female __eou__ Guo Shaoyun occupation performer __eou__ Guo Shaoyun field Star __eou__ Blessed love Comments on time.com Teacher Bai's play is to watch! __eou__ Blessed love Release date information It was shown last month __eou__ Blessed love to star Guo Shaoyun __eou__ Blessed love type Fantasy __eou__ Blessed love field film __eou__ Guo Shaoyun describe a girl from a rich family __eou__ Guo Shaoyun Ancestral home Hong Kong, China __eou__ Guo Shaoyun nation Han nationality Who do you know? __eou__ There's a lot of money there. Who are you talking about? __eou__ Guo Shaoyun, do you know? __eou__ I've heard of her. Mm-hmm! She's also the star of the movie God bless love! Have you seen it?"
# Effective sample: same knowledge and history, different final utterance.
text = "START Guo Shaoyun Blessed love __eou__ Blessed love to star Guo Shaoyun __eou__ Guo Shaoyun comment Tvber in my eyes __eou__ Guo Shaoyun date of birth 1970-8-25 __eou__ Guo Shaoyun height 168cm __eou__ Guo Shaoyun Gender female __eou__ Guo Shaoyun occupation performer __eou__ Guo Shaoyun field Star __eou__ Blessed love Comments on time.com Teacher Bai's play is to watch! __eou__ Blessed love Release date information It was shown last month __eou__ Blessed love to star Guo Shaoyun __eou__ Blessed love type Fantasy __eou__ Blessed love field film __eou__ Guo Shaoyun describe a girl from a rich family __eou__ Guo Shaoyun Ancestral home Hong Kong, China __eou__ Guo Shaoyun nation Han nationality Who do you know? __eou__ There's a lot of money there. Who are you talking about? __eou__ Guo Shaoyun, do you know? __eou__ I've heard of her. who is tongbo?"
# Command-line options for this module. The run-mode flags are declared here;
# each project component then registers its own options on the same parser
# before everything is parsed into `args`.
parser = argparse.ArgumentParser()
parser.add_argument("--do_train", type=str2bool, default=False,
                    help="Whether to run trainning.")
parser.add_argument("--do_test", type=str2bool, default=False,
                    help="Whether to run evaluation on the test dataset.")
parser.add_argument("--do_infer", type=str2bool, default=True,
                    help="Whether to run inference on the test dataset.")
parser.add_argument("--num_infer_batches", type=int, default=None,
                    help="The number of batches need to infer.\n"
                    "Stay 'None': infer on entrie test dataset.")
parser.add_argument("--hparams_file", type=str, default=None,
                    help="Loading hparams setting from file(.json format).")
# The next three options previously carried the --hparams_file help text.
# NOTE(review): none of them is read in the visible part of this file;
# presumably they are consumed by whatever launches the FastAPI app - confirm.
parser.add_argument("--host", type=str, default=None,
                    help="Host to serve the API on.")
parser.add_argument("--port", type=str, default=None,
                    help="Port to serve the API on.")
parser.add_argument("--reload", action='store_true',
                    help="Enable auto-reload of the API server.")
# Let each component add its own command-line options.
BPETextField.add_cmdline_argument(parser)
Dataset.add_cmdline_argument(parser)
Trainer.add_cmdline_argument(parser)
ModelBase.add_cmdline_argument(parser)
Generator.add_cmdline_argument(parser)
# Parsed at import time with the project's own parse_args helper.
args = parse_args(parser)
# FastAPI application object; the route handlers further down register on it.
app = FastAPI()
# Text field built from the BPETextField options; it supplies the vocabulary
# size, the example builder and the collate function used below.
bpe = BPETextField(args.BPETextField)
generator = Generator.create(args.Generator, bpe=bpe) # generation
# Copy the vocabulary size into the model options before the model is created.
args.Model.num_token_embeddings = bpe.vocab_size
# Import-time sample run: builds a loader from the module-level `text`.
# The route handlers build their own local req/data/dataset/test_loader, so the
# module-level ones created here are not read again in the visible code.
# NOTE(review): this unpack needs exactly two tab characters in `text`; confirm
# the literal contains them, otherwise this raises ValueError on import.
knowledge, src, judge = text.strip("\n").split("\t")
req = dict(context=src,knowledge=knowledge,judge=judge) # split the dialogue into knowledge, src, tgt
# These two callables are reused by the route handlers.
build_example_fn = bpe.build_example_multi_turn_with_knowledge_topic
collate_fn = bpe.collate_fn_multi_turn_with_knowledge #padding
data= build_example_fn(req)
dataset = Dataset(data)
test_loader = DataLoader(dataset, args.Trainer, collate_fn=collate_fn, is_test=args.do_infer)
# Dump the parsed options for inspection.
print(json.dumps(args,indent=4))
#for s in test_loader:
# print(s)
#print(len(test_loader))
def to_tensor(array):
    """ Add a trailing axis to `array` and wrap it as a dygraph variable. """
    with_trailing_axis = np.expand_dims(array, axis=-1)
    return fluid.dygraph.to_variable(with_trailing_axis)
#place = fluid.CUDAPlace(0)
# Device selection followed by one-time construction of the model and trainer
# that the route handlers use.
# NOTE(review): `Env` is looked up on the parsed args object here, while the
# file imports plato.modules.parallel (as `paralle`) - confirm which one is
# meant to provide Env().
if args.use_data_distributed:
    place = fluid.CUDAPlace(args.Env().dev_id)
else:
    place = fluid.CUDAPlace(0)
#place = fluid.CPUPlace()
# NOTE(review): this unconditional assignment replaces whatever the if/else
# above selected, so the guard below always uses CUDA device 2.
place = fluid.CUDAPlace(2)
# Module-level handles, filled in inside the guard below.
model =None
trainer =None
with fluid.dygraph.guard(place):
    # Construct Model
    now = time.time()
    model = ModelBase.create("Model", args, generator=generator)
    # Construct Trainer
    trainer = Trainer(model, to_tensor, args.Trainer)
    end = time.time()
    # Report how long model + trainer construction took (seconds).
    print("lasting_time_model:", end-now)
# Inference process
def split(xs, sep, pad):
    """ Cut an id sequence into non-empty runs delimited by `sep`.

    Ids equal to `pad` are dropped before anything else is considered, and
    consecutive or leading/trailing separators never produce empty runs.
    """
    groups = []
    current = []
    for token in xs:
        # Padding is ignored entirely (checked before the separator test).
        if token == pad:
            continue
        if token != sep:
            current.append(token)
            continue
        # Separator reached: close the running group if it holds anything.
        if current:
            groups.append(current)
            current = []
    # Flush the tail that was not terminated by a separator.
    if current:
        groups.append(current)
    return groups
def parse_context(batch):
    """ Split every row of `batch` at bpe.eos_id (dropping bpe.pad_id) and denumericalize. """
    rows = []
    for xs in batch.tolist():
        rows.append(split(xs, bpe.eos_id, bpe.pad_id))
    return bpe.denumericalize(rows)
def parse_text(batch):
    """ Denumericalize `batch` after converting it to nested lists. """
    id_lists = batch.tolist()
    return bpe.denumericalize(id_lists)
# Maps each inference field name to the function that parses it; handed to the
# trainer's inference calls by the route handlers.
infer_parse_dict = dict(
    src=parse_context,
    tgt=parse_text,
    preds=parse_text,
)
# Request body accepted by the "/items2/" route.
class Item(BaseModel):
    # Defines the model of the request data
    src_token: str  # forwarded to the example builder under the key "context"
    judge_token: str  # forwarded under the key "judge"
    knowledge: str  # forwarded under the key "knowledge"
# Request body accepted by the "/judge/" route; same three string fields as
# Item, declared in a different order.
class Item_judge(BaseModel):
    judge_token:str  # forwarded to the example builder under the key "judge"
    knowledge:str  # forwarded under the key "knowledge"
    src_token:str  # forwarded under the key "context"
@app.post("/items2/")
async def create_item(item: Item):# Declare it as a parameter
    """Run chat inference on one request and return the result of ``trainer.infer_chat``.

    Args:
        item: Request body carrying ``src_token``, ``judge_token`` and
            ``knowledge``.

    Returns:
        Whatever ``trainer.infer_chat`` produces for the single-request loader.
    """
    print("start_chat:")
    place = fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        now = time.time()
        #knowledge, src, judge = text.strip("\n").split("\t")
        knowledge, src, judge = item.knowledge, item.src_token, item.judge_token
        # Pack the request into the dict shape the example builder is given
        # elsewhere in this file: context / knowledge / judge.
        req = dict(context=src, knowledge=knowledge, judge=judge)
        data = build_example_fn(req)
        dataset = Dataset(data)
        test_loader = DataLoader(dataset, args.Trainer, collate_fn=collate_fn, is_test=args.do_infer)
        reply = trainer.infer_chat(test_loader, infer_parse_dict, num_batches=args.num_infer_batches)
        end = time.time()
        print("lasting_time,", end-now)
        # Log the inference result itself. The name `message` in this file is
        # the `email.message` module import, so printing it showed the module
        # object rather than the reply.
        print(reply)
        return (reply)
@app.post("/judge/")
async def create_item(item: Item_judge):# Declare it as a parameter
    # Handler for "/judge/": feeds one request through
    # trainer.infer_chat_judge_topic and returns its result unchanged.
    print("start_chat:")
    with fluid.dygraph.guard(fluid.CPUPlace()):
        started = time.time()
        # Same key layout the example builder receives elsewhere in this file.
        request_fields = dict(
            context=item.src_token,
            knowledge=item.knowledge,
            judge=item.judge_token,
        )
        examples = build_example_fn(request_fields)
        loader = DataLoader(
            Dataset(examples),
            args.Trainer,
            collate_fn=collate_fn,
            is_test=args.do_infer,
        )
        verdict = trainer.infer_chat_judge_topic(
            loader,
            infer_parse_dict,
            num_batches=args.num_infer_batches,
        )
        finished = time.time()
        # Elapsed wall-clock seconds for this request.
        print("lasting_time,", finished-started)
        return verdict