forked from Vision-CAIR/MiniGPT-4
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli_demo.py
103 lines (85 loc) · 3.52 KB
/
cli_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- becomes ``True``.
    This parser accepts the usual spellings instead. The empty string maps
    to ``False`` because that was the only way to disable the flag before.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


def parse_args():
    """Parse the command-line options of the CLI demo.

    Reads ``sys.argv`` and returns the populated ``argparse.Namespace`` with
    the attributes ``cfg_path``, ``gpu_id``, ``num_beams``, ``temperature``,
    ``english``, ``prompt_en``, ``prompt_zh`` and ``options``.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--num-beams", type=int, default=2, help="number of beams used for beam search.")
    # Temperature is fractional (default 0.9); type=int rejected values such as 0.5.
    parser.add_argument("--temperature", type=float, default=0.9, help="sampling temperature.")
    # type=bool would turn "--english False" into True; parse the text explicitly.
    parser.add_argument("--english", type=_str2bool, default=True, help="chinese or english")
    parser.add_argument(
        "--prompt-en",
        type=str,
        default="can you describe the current picture?",
        help="first question asked about a new image in English mode.",
    )
    parser.add_argument(
        "--prompt-zh",
        type=str,
        default="你能描述一下当前的图片?",
        help="first question asked about a new image in Chinese mode.",
    )
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args()
def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch random generators.

    The seed is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()`` (presumably the distributed process rank -- confirm in
    ``minigpt4.common.dist_utils``). cuDNN benchmarking is switched off and
    its deterministic mode is switched on.
    """
    rank_seed = config.run_cfg.seed + get_rank()

    # Apply the same seed to every generator the project relies on.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    cudnn.deterministic = True
    cudnn.benchmark = False
# ========================================
# Model Initialization
# ========================================
# Parse the command line and load the experiment configuration.
print('Initializing Chat')
args = parse_args()
cfg = Config(args)

# Build the model class named in the config and move it to the chosen GPU.
device = 'cuda:{}'.format(args.gpu_id)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)

# The image processor comes from the cc_sbu_align dataset's "train" settings.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device=device)
print('Initialization Finished')

# Interactive session.
#   outer loop: ask for an image path/URL ("stop" quits the program)
#   inner loop: ask questions ("clear" returns to the image prompt,
#               "stop" quits the program)
while True:
    if not args.english:
        image_path = input("请输入图像路径或URL(回车进入纯文本对话): ")
    else:
        image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ")
    if image_path == 'stop':
        break

    if len(image_path) > 0:
        # A new image starts with the configured default question.
        query = args.prompt_en if args.english else args.prompt_zh
    else:
        # Plain text conversation: there is no default question, so ask the
        # user for one. Previously `query` was undefined (NameError) or still
        # held the "clear" of the previous round at this point.
        query = input("user:")

    while True:
        if query == "clear":
            break
        if query == "stop":
            # Same effect as sys.exit(0); `sys` is not imported in this file,
            # so the original call raised NameError instead of exiting.
            raise SystemExit(0)

        # Every turn starts from a fresh conversation state, so earlier
        # turns are not carried over into the next answer.
        img_list = []
        chat_state = CONV_VISION.copy()
        if image_path:
            chat.upload_img(image_path, chat_state, img_list)
        # NOTE(review): the text-only path assumes chat.answer accepts an
        # empty img_list -- confirm against Chat in minigpt4.conversation.
        chat.ask(query, chat_state)
        llm_message = chat.answer(
            conv=chat_state,
            img_list=img_list,
            num_beams=args.num_beams,
            temperature=args.temperature,
            max_new_tokens=300,
            max_length=2000
        )[0]
        print("MiniGPT4:"+llm_message)
        query = input("user:")