llm.py
import os
import abc
import asyncio
from abc import abstractmethod
import math
import tiktoken
import openai
import backoff


class LLM(abc.ABC):
    prompt_percent = 0.9

    @abstractmethod
    def __init__(self):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def infer(self, prompts):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
        raise NotImplementedError("Subclasses should implement this!")


class GPT(LLM):
    prompt_percent = 0.8

    openai_cxn_dict = {
        'default': {
            'endpoint': "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
            'api_key': "INSERT YOUR AZURE OPENAI API KEY HERE",
        },
    }

    deployment_max_length_dict = {
        'gpt-4': 8192,
        'gpt-4-0314': 8192,
        'gpt-4-32k': 32768,
        'gpt-35-turbo': 4096,
        'gpt-35-turbo-16k': 16385,
    }

    def __init__(self, model_id):
        self.temperature = 0.0
        self.top_k = 1
        # Map the Azure deployment name (e.g. 'gpt-35-turbo') to the model name
        # tiktoken expects (e.g. 'gpt-3.5') to select the matching tokenizer.
        self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace('5', '.5'))
        self.openai_api = 'default'
        self.model_id = model_id
        self.max_length = self.deployment_max_length_dict[model_id]
        self.client = openai.AsyncAzureOpenAI(
            api_key=self.openai_cxn_dict[self.openai_api]['api_key'],
            api_version="2023-12-01-preview",
            azure_endpoint=self.openai_cxn_dict[self.openai_api]['endpoint'],
        )
    def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
        # Build a chat message list: system instruction first, then alternating
        # user/assistant few-shot examples, then the actual input as the final user turn.
        messages = [
            {
                "role": "system",
                "content": fixed_instruction,
            },
        ]
        for example in few_shot_examples:
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": input_header + '\n' + example['user'] + '\n\n' + output_header,
                    },
                    {
                        "role": "assistant",
                        "content": example['assistant'],
                    },
                ]
            )
        messages.extend(
            [
                {
                    "role": "user",
                    "content": input_header + '\n' + input + '\n\n' + output_header,
                },
            ]
        )
        return messages
    # Coroutine for a single chat-completion API call, retried with exponential
    # backoff whenever the endpoint returns a rate-limit error.
    @backoff.on_exception(backoff.expo, openai.RateLimitError)
    async def make_api_call_to_gpt(self, messages):
        response = await self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temperature,
        )
        return response.choices[0].message.content

    async def dispatch_openai_requests(self, messages_list):
        # Launch one API call per message list and run them concurrently.
        tasks = [self.make_api_call_to_gpt(messages) for messages in messages_list]
        results = await asyncio.gather(*tasks)
        return results

    def infer(self, messages_list):
        return asyncio.run(self.dispatch_openai_requests(messages_list))
    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
        # Tokenize the fixed instruction together with all few-shot examples.
        fixed_token_ids = self.encoding.encode(fixed_instruction + ' '.join([x['user'] + ' ' + x['assistant'] for x in few_shot_examples]))
        # Calculate how many tokens remain for the splittable input within the prompt budget.
        remaining_token_len = math.ceil((self.prompt_percent * self.max_length) - len(fixed_token_ids))
        # Tokenize the splittable input.
        split_token_ids = self.encoding.encode(splittable_input)
        # Split the tokenized input into chunks that fit the remaining budget, with a
        # 10-token overlap between consecutive chunks, then decode each chunk back to a string.
        split_token_ids_list = [split_token_ids[i:i + remaining_token_len + 10] for i in range(0, len(split_token_ids), remaining_token_len)]
        split_input_list = [self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list]
        # Combine each chunk with the fixed instruction, few-shot examples, and
        # input/output headers to produce one chat message list per chunk.
        return [self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header) for split_input in split_input_list]
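

# Example usage (a minimal sketch, not part of the original file). The model id,
# instruction, few-shot example, and headers below are hypothetical placeholders;
# it assumes the Azure endpoint/key placeholders in GPT.openai_cxn_dict have been
# filled in and a deployment named 'gpt-4' exists.
if __name__ == "__main__":
    llm = GPT('gpt-4')

    fixed_instruction = "Summarize the following text."  # hypothetical instruction
    few_shot_examples = [
        {'user': 'Example input text.', 'assistant': 'Example summary.'},  # hypothetical example
    ]
    long_document = "..."  # a long input that may exceed the model's context window

    # split_input chunks the long input to fit the prompt budget and returns one
    # chat message list per chunk; infer sends all of them concurrently.
    messages_list = llm.split_input(
        fixed_instruction,
        few_shot_examples,
        long_document,
        input_header="INPUT:",
        output_header="OUTPUT:",
    )
    outputs = llm.infer(messages_list)
    for out in outputs:
        print(out)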