-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_model.py
103 lines (82 loc) · 3.13 KB
/
run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import http.server
import json
import socketserver
import textwrap
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from transformers.generation.configuration_utils import GenerationConfig
# Model checkpoint directory name (expected under models/transformers/).
LLAMA_MODEL = "Llama3.2-1B-Instruct"
# Load the tokenizer and Llama model from the local checkpoint directory
# (paths are relative to the process working directory — no hub download).
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    f"models/transformers/{LLAMA_MODEL}"
)
model = LlamaForCausalLM.from_pretrained(f"models/transformers/{LLAMA_MODEL}")
# Model configurations: prefer GPU when available, fall back to CPU if the
# device runs out of memory while the weights are being moved.
try:
    if torch.cuda.is_available():
        model = model.to("cuda")
except torch.OutOfMemoryError:
    # NOTE(review): torch.OutOfMemoryError exists only in recent torch
    # releases (older versions expose it as torch.cuda.OutOfMemoryError) —
    # confirm against the pinned torch version.
    model = model.to("cpu")
# Set a pad token id so generate() does not warn about a missing pad token.
model.generation_config.pad_token_id = tokenizer.pad_token_id
# Generation settings shared by every request.
generation_config = GenerationConfig(
    max_new_tokens=128 * 10**2,  # up to 12,800 new tokens per request
    max_time=60.0*5,  # hard cap of 300 s of generation time
    stop_strings=["<|end_of_text|>", "<|eot_id|>", "<|eom_id|>"],
)
def generate_text(prompt):
    """Run the Llama model on *prompt* and return the decoded output.

    Args:
        prompt (str): Input text guiding generation. Common leading
            indentation and surrounding whitespace are stripped first.

    Returns:
        str: The decoded model output (special tokens kept, prompt
        included) with the ``<|begin_of_text|>`` marker removed.
    """
    # Re-assert the pad token id before every call so generation never
    # warns about a missing pad token.
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    cleaned_prompt = textwrap.dedent(prompt).strip()
    encoded = tokenizer(cleaned_prompt, return_tensors="pt")
    encoded = encoded.to(model.device)
    token_ids = model.generate(
        **encoded,
        generation_config=generation_config,
        tokenizer=tokenizer,
    )
    decoded = tokenizer.decode(token_ids[0], skip_special_tokens=False)
    decoded = textwrap.dedent(decoded).strip()
    return decoded.replace("<|begin_of_text|>", "")
class RequestHandler(http.server.SimpleHTTPRequestHandler):
    """HTTP request handler that serves text generation over POST.

    Expects a JSON object body of the form ``{"prompt": "..."}`` and
    responds with ``{"generated_text": "..."}``.
    """

    def do_POST(self):
        """Handle a POST request: parse the JSON body, generate text,
        and send back the generated text as JSON.

        Sends 411 when Content-Length is missing or invalid, 400 on a
        malformed or non-object JSON body, 500 when generation fails.
        """
        # Read the length of the content; reject requests without a
        # usable Content-Length instead of crashing with KeyError.
        try:
            content_length = int(self.headers["Content-Length"])
        except (KeyError, TypeError, ValueError):
            self.send_error(411, "Missing or invalid Content-Length")
            return
        post_data = self.rfile.read(content_length)
        # Load data from JSON; require a JSON object so .get() below
        # cannot raise AttributeError on e.g. a top-level list.
        try:
            data = json.loads(post_data)
        except json.decoder.JSONDecodeError:
            self.send_error(400, "Invalid JSON")
            return
        if not isinstance(data, dict):
            self.send_error(400, "Invalid JSON")
            return
        prompt = data.get("prompt", "")
        # Generate text based on the prompt; surface failures as a 500
        # instead of letting the exception escape the handler.
        try:
            generated_text = generate_text(prompt)
        except Exception:
            self.send_error(500, "Text generation failed")
            return
        # Prepare the response
        response = {"generated_text": generated_text}
        try:
            # Send the response; the client may disconnect mid-write.
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            self.wfile.write(json.dumps(response).encode("utf-8"))
        except (BrokenPipeError, ConnectionResetError):
            return
# Define the server port
PORT = 5000
# Create and start the server (original comment claimed 120 s, but the
# value set below is 60*5 = 300 seconds).
try:
    with socketserver.TCPServer(("", PORT), RequestHandler) as httpd:
        # NOTE(review): BaseServer.timeout only applies to
        # handle_request(); serve_forever() ignores it — confirm intent.
        httpd.timeout = 60*5  # 300-second socket timeout
        print(f"Server running on port {PORT}")
        httpd.serve_forever()
except KeyboardInterrupt:
    # Allow a clean Ctrl-C shutdown without a traceback.
    print("\nServer stopped by user.")