-
Notifications
You must be signed in to change notification settings - Fork 1
/
palm.py
65 lines (54 loc) · 1.91 KB
/
palm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from google.auth import credentials
from google.oauth2 import service_account
import google.cloud.aiplatform as aiplatform
from vertexai.preview.language_models import ChatModel, InputOutputTextPair
import vertexai
import json # add this line
# --- Google Cloud / Vertex AI / FastAPI setup ------------------------------
# Read the service-account key file ONCE and reuse the parsed dict for both
# the credentials object and the project id.  (The original opened and
# json.load()-ed the same file twice.)
p_service_acct = r"/etc/secrets/service_account.json"
with open(p_service_acct, encoding="utf-8") as f:
    service_account_info = json.load(f)

# Build Google credentials from the in-memory key data.
my_credentials = service_account.Credentials.from_service_account_info(
    service_account_info
)

# Initialize Google AI Platform with the service-account credentials.
aiplatform.init(
    credentials=my_credentials,
)

# Initialize Vertex AI with the project id taken from the same key file.
# NOTE(review): region "us-central1" is hard-coded — confirm this matches
# the project's deployment region.
project_id = service_account_info["project_id"]
vertexai.init(project=project_id, location="us-central1")

# Initialize the FastAPI application.
app = FastAPI()
# Chat with the model
async def handle_chat(human_msg: str):
    """Send *human_msg* to the chat-bison model and return its reply.

    Returns a dict of the form ``{"response": <model reply text>}``.

    NOTE(review): the docstring/comments in the original call this an
    "endpoint", but no FastAPI route decorator registers it — confirm
    whether an ``@app.post(...)`` is missing.
    """
    # Sampling / length settings forwarded to the model on every call.
    generation_config = {
        "temperature": 0.8,
        "max_output_tokens": 1024,
        "top_p": 0.8,
        "top_k": 40,
    }
    # A fresh model handle and chat session are created per call.
    model = ChatModel.from_pretrained("chat-bison@001")
    session = model.start_chat()
    reply = session.send_message(human_msg, **generation_config)
    return {"response": reply.text}
async def main(prompt):
    """Run a single chat turn for *prompt* and print the model's reply."""
    result = await handle_chat(prompt)
    print(result["response"])
# CLI entry point: read one prompt from stdin and run a single chat turn.
if __name__ == "__main__":
    import asyncio

    user_prompt = input("Enter your prompt")
    asyncio.run(main(user_prompt))