-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
363 lines (309 loc) · 16.1 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import arxiv
import aiofiles
import aiofiles.os
import asyncio
import logging
import chainlit as cl
from openai import AsyncOpenAI
from chainlit.context import context
from chainlit.user_session import user_session
from aiohttp import ClientSession
from metadata_pipeline import daily_metadata_task
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
logging.basicConfig(filename='combined_log.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
daily_task_scheduled = False
def initialize_embeddings():
"""Initialize the OpenAI embedding model."""
logger.info("Initializing OpenAI embeddings...")
return OpenAIEmbeddings(model="text-embedding-3-small")
def initialize_vector_stores(embedding_model):
"""Initialize Pinecone vector stores for metadata and chunks."""
logger.info("Initializing Pinecone vector stores...")
metadata_vector_store = PineconeVectorStore.from_existing_index(embedding=embedding_model, index_name="arxiv-rag-metadata")
chunks_vector_store = PineconeVectorStore.from_existing_index(embedding=embedding_model, index_name="arxiv-rag-chunks")
return metadata_vector_store, chunks_vector_store
def initialize_text_splitter():
"""Initialize the recursive character text splitter."""
logger.info("Initializing text splitter...")
return RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=50,
length_function=len,
is_separator_regex=False
)
async def send_actions():
"""Send action options to the user."""
actions = [
cl.Action(name="ask_followup_question", value="followup_question", description="Uses The Previously Retrieved Context", label="Ask a Follow-Up Question"),
cl.Action(name="ask_new_question", value="new_question", description="Retrieves New Context", label="Ask a New Question About the Same Paper"),
cl.Action(name="ask_about_new_paper", value="new_paper", description="Ask About A Different Paper", label="Ask About a Different Paper")
]
await cl.Message(content="### Please Select One of the Following Options", actions=actions).send()
@cl.on_stop
async def on_stop():
"""Handle session stop event to clean up tasks."""
streaming_task = user_session.get('streaming_task')
if streaming_task:
streaming_task.cancel()
await send_actions()
user_session.set('streaming_task', None)
logger.info("Session stopped and streaming task cleaned up.")
@cl.on_chat_start
async def main():
"""Main function to start the chat session."""
global daily_task_scheduled
if not daily_task_scheduled:
asyncio.create_task(daily_metadata_task())
daily_task_scheduled = True
embedding_model = initialize_embeddings()
metadata_vector_store, chunks_vector_store = initialize_vector_stores(embedding_model)
text_splitter = initialize_text_splitter()
user_session.set('embedding_model', embedding_model)
user_session.set('metadata_vector_store', metadata_vector_store)
user_session.set('chunks_vector_store', chunks_vector_store)
user_session.set('text_splitter', text_splitter)
user_session.set('current_document_id', None)
# Start with the initial query
await ask_initial_query(initial=True)
async def ask_initial_query(initial=False):
"""Prompt the user to enter the title of the research paper."""
if initial:
# Combined welcome message and query prompt
text_content = """## Welcome to arXivGPT
arXivGPT helps students, researchers, and enthusiasts by providing real-time access to the latest research uploaded to arXiv.
With daily updates, it ensures users always have the most recent information.
### Instructions
1. **Enter the Title**: Begin by entering the title of the research paper.
2. **Select a Paper**: Choose a paper from the list by entering its number.
3. **Database Check**: The system checks if the paper is in the database.
- If yes, you'll be prompted to enter your question.
- If not, the paper is downloaded, then you'll be prompted to ask your question.
4. **Read the Answer**: After receiving the answer, you can:
- Ask a follow-up question.
- Ask a new question about the same paper.
- Ask about a different paper.
### Get Started
Enter the title of the research paper you wish to learn more about.
"""
res = await cl.AskUserMessage(content=text_content, timeout=3600).send()
else:
res = await cl.AskUserMessage(content="### Please Enter the Title of the Research Paper You Wish to Learn More About", timeout=3600).send()
if res:
initial_query = res['output']
metadata_vector_store = user_session.get('metadata_vector_store')
logger.info(f"Searching for metadata with query: {initial_query}")
search_results = metadata_vector_store.similarity_search(query=initial_query, k=5)
logger.info(f"Metadata search results: {search_results}")
selected_doc_id = await select_document_from_results(search_results)
if selected_doc_id:
logger.info(f"Document selected with ID: {selected_doc_id}")
user_session.set('current_document_id', selected_doc_id)
chunks_exist = await do_chunks_exist_already(selected_doc_id)
if not chunks_exist:
await process_and_upload_chunks(selected_doc_id)
else:
await ask_user_question(selected_doc_id)
async def ask_user_question(document_id):
"""Prompt the user to enter a question about the selected document."""
logger.info(f"Asking user question for document_id: {document_id}")
context, user_query = await process_user_query(document_id)
if context and user_query:
task = asyncio.create_task(query_openai_with_context(context, user_query))
user_session.set('streaming_task', task)
await task
async def select_document_from_results(search_results):
"""Allow user to select a document from the search results."""
if not search_results:
await cl.Message(content="No Search Results Found").send()
return None
message_content = "### Please Enter the Number Corresponding to Your Desired Paper\n"
message_content += "| No. | Paper Title | Doc. ID |\n"
message_content += "|-----|-------------|---------|\n"
for i, doc in enumerate(search_results, start=1):
page_content = doc.page_content
document_id = doc.metadata['document_id']
message_content += f"| {i} | {page_content} | {document_id} |\n"
await cl.Message(content=message_content).send()
while True:
res = await cl.AskUserMessage(content="", timeout=3600).send()
if res:
try:
user_choice = int(res['output']) - 1
if 0 <= user_choice < len(search_results):
selected_doc_id = search_results[user_choice].metadata['document_id']
selected_paper_title = search_results[user_choice].page_content
await cl.Message(content=f"\n**You selected:** {selected_paper_title}").send()
return selected_doc_id
else:
await cl.Message(content="\nInvalid Selection. Please enter a valid number from the list.").send()
except ValueError:
await cl.Message(content="\nInvalid input. Please enter a number.").send()
else:
await cl.Message(content="\nNo selection made. Please enter a valid number from the list.").send()
async def do_chunks_exist_already(document_id):
"""Check if chunks for the document already exist."""
chunks_vector_store = user_session.get('chunks_vector_store')
filter = {"document_id": {"$eq": document_id}}
test_query = chunks_vector_store.similarity_search(query="Chunks Existence Check", k=1, filter=filter)
logger.info(f"Chunks existence check result for document_id {document_id}: {test_query}")
return bool(test_query)
async def download_pdf(session, document_id, url, filename):
"""Download the PDF file asynchronously."""
logger.info(f"Downloading PDF for document_id: {document_id} from URL: {url}")
async with session.get(url) as response:
if response.status == 200:
async with aiofiles.open(filename, mode='wb') as f:
await f.write(await response.read())
logger.info(f"Successfully downloaded PDF for document_id: {document_id}")
else:
logger.error(f"Failed to download PDF for document_id: {document_id}, status code: {response.status}")
raise Exception(f"Failed to download PDF: {response.status}")
async def process_and_upload_chunks(document_id):
"""Download, process, and upload chunks of the document."""
await cl.Message(content="#### It seems that paper isn't currently in our database. Don't worry, we are currently downloading, processing, and uploading it. This will only take a few moments.").send()
await asyncio.sleep(2)
try:
async with ClientSession() as session:
paper = await asyncio.to_thread(next, arxiv.Client().results(arxiv.Search(id_list=[str(document_id)])))
url = paper.pdf_url
filename = f"{document_id}.pdf"
await download_pdf(session, document_id, url, filename)
loader = PyPDFLoader(filename)
pages = await asyncio.to_thread(loader.load)
text_splitter = user_session.get('text_splitter')
content = []
found_references = False
for page in pages:
if found_references:
break
page_text = page.page_content
if "references" in page_text.lower():
content.append(page_text.split("References")[0])
found_references = True
else:
content.append(page_text)
full_content = ''.join(content)
chunks = text_splitter.split_text(full_content)
embedding_model = user_session.get('embedding_model')
if not embedding_model:
raise ValueError("Embedding model not initialized")
chunks_vector_store = user_session.get('chunks_vector_store')
await asyncio.to_thread(
chunks_vector_store.from_texts,
texts=chunks,
embedding=embedding_model,
metadatas=[{"document_id": document_id} for _ in chunks],
index_name="arxiv-rag-chunks"
)
await aiofiles.os.remove(filename)
logger.info(f"Successfully processed and uploaded chunks for document_id: {document_id}")
await ask_user_question(document_id)
except Exception as e:
logger.error(f"Error processing and uploading chunks for document_id {document_id}: {e}")
await cl.Message(content="#### An error occurred during processing. Please try again.").send()
return
async def process_user_query(document_id):
"""Process the user's query about the document."""
res = await cl.AskUserMessage(content="### Please Enter Your Question", timeout=3600).send()
if res:
user_query = res['output']
context = []
chunks_vector_store = user_session.get('chunks_vector_store')
filter = {"document_id": {"$eq": document_id}}
attempts = 5
for attempt in range(attempts):
search_results = chunks_vector_store.similarity_search(query=user_query, k=15, filter=filter)
logger.info(f"Context retrieval attempt {attempt + 1}: Found {len(search_results)} results")
context = [doc.page_content for doc in search_results]
if context:
break
logger.info(f"No context found, retrying... (attempt {attempt + 1}/{attempts})")
await asyncio.sleep(2)
logger.info(f"User query processed. Context length: {len(context)}, User Query: {user_query}")
return context, user_query
return None, None
async def query_openai_with_context(context, user_query):
"""Query OpenAI with the context and user query."""
if not context:
await cl.Message(content="No context available to answer the question.").send()
return
client = AsyncOpenAI()
settings = {
"model": "gpt-4o",
"temperature": 0.3,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
}
message_history = [
{"role": "system", "content": """
Your job is to answer the user's query using only the provided context.
Be detailed and long-winded. Format your responses in markdown formatting, making good use of headings,
subheadings, ordered and unordered lists, and regular text formatting such as bolding of text and italics.
Sometimes the equations retrieved from the context will be formatted improperly and in an incompatible format
for correct LaTeX rendering. Therefore, if you ever need to provide equations, make sure they are
formatted properly using LaTeX. For in-line equations make sure you wrap the equation in single dollar signs $
and for bigger, more visual equations make sure to wrap the equation in double dollar signs. Remember, some of the
LaTeX may be given to you in an incorrect or weird format.
For example, "( \kappa )" wouldn't render inline, however "$( \kappa )$" would render. Similarly "( \ell_{\text{margin}} )"
wouldn't render either while "$( \ell_{\text{margin}} )$" would render. Make sure you render equations correctly.
Don't change the content of the equation, but just fix up what you have been given. Keep your answer grounded in the facts
of the provided context. If the context does not contain the facts needed to answer the user's query, return:
"I do not have enough information available to accurately answer the question."
"""},
{"role": "user", "content": f"Context: {context}"},
{"role": "user", "content": f"Question: {user_query}"}
]
msg = cl.Message(content="")
await msg.send()
async def stream_response():
stream = await client.chat.completions.create(messages=message_history, stream=True, **settings)
async for part in stream:
if token := part.choices[0].delta.content:
await msg.stream_token(token)
streaming_task = asyncio.create_task(stream_response())
user_session.set('streaming_task', streaming_task)
try:
await streaming_task
except asyncio.CancelledError:
streaming_task.cancel()
return
await msg.update()
await send_actions()
@cl.action_callback("ask_followup_question")
async def handle_followup_question(action):
"""Handle follow-up question action."""
logger.info("Follow-up question button clicked.")
current_document_id = user_session.get('current_document_id')
if current_document_id:
context, user_query = await process_user_query(current_document_id)
if context and user_query:
logger.info(f"Processing follow-up question for document_id: {current_document_id}")
task = asyncio.create_task(query_openai_with_context(context, user_query))
user_session.set('streaming_task', task)
await task
else:
logger.warning("Context or user query not found for follow-up question.")
else:
logger.warning("No current document ID found for follow-up question.")
@cl.action_callback("ask_new_question")
async def handle_new_question(action):
"""Handle new question action."""
logger.info("New question about the same paper button clicked.")
current_document_id = user_session.get('current_document_id')
if current_document_id:
logger.info(f"Asking new question for document_id: {current_document_id}")
await ask_user_question(current_document_id)
else:
logger.warning("No current document ID found for new question.")
@cl.action_callback("ask_about_new_paper")
async def handle_new_paper(action):
"""Handle new paper action."""
logger.info("New paper button clicked.")
await ask_initial_query(initial=False)
if __name__ == "__main__":
asyncio.run(main())