-
Notifications
You must be signed in to change notification settings - Fork 138
/
app.py
201 lines (149 loc) · 5.91 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import ollama
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_community.document_loaders import BSHTMLLoader
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_pinecone import PineconeVectorStore
import logging
import os
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from traceloop.sdk import Traceloop
from traceloop.sdk.decorators import workflow, task
# Opt out of Traceloop's own usage telemetry; must be set before Traceloop.init().
os.environ["TRACELOOP_TELEMETRY"] = "false"
def read_token():
    """Return the API token used to authenticate the OTLP exporter.

    Prefers the ``API_TOKEN`` environment variable and falls back to the
    mounted secret file only when the variable is absent.
    """
    token = os.environ.get("API_TOKEN")
    # os.environ.get evaluates its default eagerly, so the original code
    # always read the secret file (and printed "No token was provided")
    # even when API_TOKEN was set.  Look up the secret lazily instead.
    return token if token is not None else read_secret("token")
def read_endpoint():
    """Return the OTLP collector endpoint.

    Prefers the ``OTEL_ENDPOINT`` environment variable and falls back to
    the mounted secret file only when the variable is absent.
    """
    endpoint = os.environ.get("OTEL_ENDPOINT")
    # Evaluate the secret-file fallback lazily — the eager default in the
    # original always read the file and printed a spurious warning.
    return endpoint if endpoint is not None else read_secret("endpoint")
def read_pinecone_key():
    """Return the Pinecone API key read from the mounted secret file."""
    key = read_secret("api-key")
    return key
def read_secret(secret: str) -> str:
    """Read a mounted secret from ``/etc/secrets/<secret>``.

    Returns the file contents with trailing whitespace stripped, or an
    empty string when the secret is missing or unreadable (deliberate
    best-effort: a missing secret is reported, not fatal).
    """
    try:
        with open(f"/etc/secrets/{secret}", "r") as f:
            return f.read().rstrip()
    # Narrowed from a bare `except Exception` to the realistic failure
    # modes: missing/unreadable file, or a non-UTF-8 secret.
    except (OSError, UnicodeDecodeError) as e:
        print(f"No {secret} was provided")
        print(e)
        return ""
OTEL_ENDPOINT = read_endpoint()
# Traceloop expects the collector base URL, so strip a trailing OTLP
# traces path if the configured endpoint already includes one.
# NOTE(review): find() locates the FIRST occurrence of "/v1/traces"; for
# normal URLs this equals removesuffix semantics, but confirm no endpoint
# can contain that substring earlier in the path.
if OTEL_ENDPOINT.endswith("/v1/traces"):
    OTEL_ENDPOINT = OTEL_ENDPOINT[: OTEL_ENDPOINT.find("/v1/traces")]
OLLAMA_ENDPOINT = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
# GLOBALS
AI_MODEL = os.environ.get("AI_MODEL", "orca-mini:3b")
AI_SYSTEM = "llama"
AI_EMBEDDING_MODEL = os.environ.get("AI_EMBEDDING_MODEL", "orca-mini:3b")
# Prompt limit quoted in user-facing error messages.
# NOTE(review): nothing in this file actually enforces this limit — confirm intent.
MAX_PROMPT_LENGTH = 50
retrieval_chain = None  # populated by prep_system() in the __main__ block
# Initialise the logger
logging.basicConfig(level=logging.INFO, filename="run.log")
logger = logging.getLogger(__name__)
# ################
# # CONFIGURE OPENTELEMETRY
# Resource identifying this service in the tracing backend.
# NOTE(review): `resource` is created but never passed to Traceloop.init
# below — confirm whether it should be wired in.
resource = Resource.create(
    {"service.name": "travel-advisor", "service.version": "0.2.1"}
)
TOKEN = read_token()
# Dynatrace-style API-token authorization header for the OTLP exporter.
headers = {"Authorization": f"Api-Token {TOKEN}"}
otel_tracer = trace.get_tracer("travel-advisor")
# disable_batch=True exports spans immediately instead of batching.
Traceloop.init(
    app_name="travel-advisor",
    api_endpoint=OTEL_ENDPOINT,
    disable_batch=True,
    headers=headers,
)
def prep_system():
    """Build and return the RAG retrieval chain.

    Loads every HTML page under ``destinations/``, splits the documents,
    indexes them in the "travel-advisor" Pinecone index, and wires a
    ChatOllama LLM behind the retriever.  The returned chain is invoked
    with ``{"input": <question>}`` and answers via ``{"answer": ...}``.
    """
    # Create the embedding model served by the shared Ollama endpoint
    embeddings = OllamaEmbeddings(model=AI_EMBEDDING_MODEL, base_url=OLLAMA_ENDPOINT)
    # Retrieve the source data: one entry per parsed HTML page.
    # (Fix: the original reused `item` as the inner loop variable,
    # shadowing the filename; extend() avoids the inner loop entirely.)
    docs_list = []
    for item in os.listdir(path="destinations"):
        if item.endswith(".html"):
            docs_list.extend(BSHTMLLoader(file_path=f"destinations/{item}").load())
    # Split Document into tokens
    text_splitter = RecursiveCharacterTextSplitter()
    documents = text_splitter.split_documents(docs_list)
    logger.info("Loading documents from PineCone...")
    vector = PineconeVectorStore.from_documents(
        documents, index_name="travel-advisor", embedding=embeddings
    )
    retriever = vector.as_retriever()
    logger.info("Initialising Llama LLM...")
    llm = ChatOllama(model=AI_MODEL, base_url=OLLAMA_ENDPOINT)
    prompt = ChatPromptTemplate.from_template(
        """Answer the following question based only on the provided context:
<context>
{context}
</context>
Question: Give travel advise in a paragraph of max 50 words about {input}
"""
    )
    # Controls how each retrieved document is rendered into {context}.
    document_prompt = PromptTemplate(
        input_variables=["page_content", "source"],
        template="content:{page_content}\nsource:{source}",
    )
    document_chain = create_stuff_documents_chain(
        llm=llm,
        prompt=prompt,
        document_prompt=document_prompt,
    )
    return create_retrieval_chain(retriever, document_chain)
############
# CONFIGURE ENDPOINTS
# FastAPI application; routes are registered below, static files in __main__.
app = FastAPI()
####################################
@app.get("/api/v1/completion")
def submit_completion(framework: str, prompt: str):
    """Completion endpoint: route to the plain LLM or the RAG chain.

    ``framework`` selects "llm" (direct Ollama call) or "rag" (retrieval
    chain); any other value returns an "invalid Mode" message.
    """
    # NOTE(review): this handler shares its name with the RAG workflow
    # function defined later in the file, which rebinds `submit_completion`
    # at module level.  The route keeps working because FastAPI captured
    # this function at decoration time, and the call below intentionally
    # resolves to the LATER definition — consider renaming one of them.
    with otel_tracer.start_as_current_span(name="/api/v1/completion") as span:
        if framework == "llm":
            return llm_chat(prompt, span)
        if framework == "rag":
            return submit_completion(prompt, span)
        return {"message": "invalid Mode"}
@task(name="ollama_chat")
def llm_chat(prompt: str, span):
prompt = f"Give travel advise in a paragraph of max 50 words about {prompt}"
res = ollama.generate(model=AI_MODEL, prompt=prompt)
return {"message": res.get("response")}
@workflow(name="travelgenerator")
def submit_completion(prompt: str, span):
if prompt:
logger.info(f"Calling RAG to get the answer to the question: {prompt}...")
response = retrieval_chain.invoke({"input": prompt}, config={})
# Log information for DQL to grab
logger.info(
f"Response: {response}. Using RAG. model={AI_MODEL}. prompt={prompt}"
)
return {"message": response["answer"]}
else: # No, or invalid prompt given
span.add_event(
f"No prompt provided or prompt too long (over {MAX_PROMPT_LENGTH} chars)"
)
return {
"message": f"No prompt provided or prompt too long (over {MAX_PROMPT_LENGTH} chars)"
}
####################################
@app.get("/api/v1/thumbsUp")
@otel_tracer.start_as_current_span("/api/v1/thumbsUp")
def thumbs_up(prompt: str):
    """Record positive user feedback for a search term (log-only endpoint)."""
    # Lazy %-style args defer string formatting until the record is emitted.
    logger.info("Positive user feedback for search term: %s", prompt)
@app.get("/api/v1/thumbsDown")
@otel_tracer.start_as_current_span("/api/v1/thumbsDown")
def thumbs_down(prompt: str):
logger.info(f"Negative user feedback for search term: {prompt}")
if __name__ == "__main__":
retrieval_chain = prep_system()
# Mount static files at the root
app.mount("/", StaticFiles(directory="./public", html=True), name="public")
# app.mount("/destinations", StaticFiles(directory="destinations", html = True), name="destinations")
# Run the app using uvicorn
uvicorn.run(app, host="0.0.0.0", port=8080)