Skip to content

Commit

Permalink
refactor: decouple the message handler
Browse files Browse the repository at this point in the history
  • Loading branch information
thehapyone committed Oct 20, 2024
1 parent 9d2c117 commit cf2fa36
Showing 1 changed file with 74 additions and 45 deletions.
119 changes: 74 additions & 45 deletions sage/sources/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,52 +126,56 @@ def _get_starter_source_label(message: str) -> Tuple[str, str]:
logger.warning("Error extracting label from starter %s", error)
return message, "none"

@cl.on_message
async def on_message(self, message: cl.Message):
"""Function to react user's message request"""
if self.mode == "tool":
raise ValueError("Tool mode is not supported here")
async def _handle_home_command(
self, message: cl.Message, chat_profile: str
) -> None:
"""Handle the '/home' command."""
intro_message = self.get_intro_message(chat_profile)
retriever = await self._mode_handlers.handle_chat_only_mode(
intro_message, message.id
)
self._runnable_handlers.setup_runnable(retriever=retriever)

# Handle starter message only once
if (
cl.user_session.get("starter_message", True)
and cl.user_session.get("chat_profile") == "Chat Only"
):
async def _handle_starter_message_if_needed(self, message: cl.Message) -> None:
"""Handle the starter message if it's the first interaction."""
starter_message = cl.user_session.get("starter_message", True)
chat_profile = cl.user_session.get("chat_profile")

if starter_message and chat_profile == "Chat Only":
cl.user_session.set("starter_message", False)
chat_profile = cl.user_session.get("chat_profile")

if message.content == "/home":
retriever = await self._mode_handlers.handle_chat_only_mode(
self.get_intro_message(chat_profile), message.id
)
self._runnable_handlers.setup_runnable(retriever=retriever)
await self._handle_home_command(message, chat_profile)
return

message.content, source_label = self._get_starter_source_label(
message.content
)
content, source_label = self._get_starter_source_label(message.content)
message.content = content
await message.update()

# Now we should set the retriever for the other messages
retriever = await self._mode_handlers.handle_chat_only_mode(
"", source_label=source_label
"",
source_label=source_label,
)
self._runnable_handlers.setup_runnable(retriever=retriever)

runnable: RunnableSequence = cl.user_session.get("runnable")
memory = get_memory(self.mode, cl.user_session)

msg = cl.Message(content="")
## Process and prepare image data
def _build_query(self, message: cl.Message) -> dict:
"""Build the query dictionary for the model inputs the message."""
images = [file for file in message.elements if file.mime.startswith("image")]
return {"question": message.content, "image_data": images if images else None}

query = {"question": message.content, "image_data": images if images else None}

_sources = None
_answer = None
text_elements: list[cl.Text] = []
async def _stream_answer(self, answer: str, output_message: cl.Message) -> None:
"""Stream the answer tokens to the user."""
await output_message.stream_token(answer)

async def _process_query(
self, runnable: RunnableSequence, query: dict
) -> cl.Message:
"""Process the query through the runnable and stream the answer."""
run_name = getattr(runnable, "config", {}).get("run_name", "")
answer = None
sources = None

final_message = cl.Message(content="")

async for chunk in runnable.astream(
query,
Expand All @@ -197,19 +201,21 @@ async def on_message(self, message: cl.Message):
],
),
):
_answer = chunk.get("answer")
if _answer:
await msg.stream_token(_answer)

if chunk.get("sources"):
_sources = chunk.get("sources")

# process sources
if _sources:
for source_doc in _sources:
if chunk_answer := chunk.get("answer"):
answer = chunk_answer

Check failure on line 205 in sage/sources/qa.py

View workflow job for this annotation

GitHub Actions / Code Style

Ruff (F841)

sage/sources/qa.py:205:17: F841 Local variable `answer` is assigned to but never used
# Stream the answer token by token
await self._stream_answer(chunk_answer, final_message)

if chunk_sources := chunk.get("sources"):
sources = chunk_sources

# process sources if available
source_elements: list[cl.Text] = []
if sources:
for source_doc in sources:
source_name = f"[{source_doc['id']}]"
source_content = source_doc["content"] + "\n" + source_doc["source"]
text_elements.append(
source_elements.append(
cl.Text(
content=source_content,
name=source_name,
Expand All @@ -218,12 +224,35 @@ async def on_message(self, message: cl.Message):
)
)

msg.elements = text_elements
await msg.send()
final_message.elements = source_elements

return final_message

def _update_memory(
self, user_message: cl.Message, ai_message: cl.Message
) -> None:
"""Update the conversation memory with the latest messages."""
memory = get_memory(self.mode, cl.user_session)
memory.chat_memory.add_ai_message(ai_message.content)
memory.chat_memory.add_user_message(user_message.content)

@cl.on_message
async def on_message(self, user_message: cl.Message) -> None:
"""Handle and respond to a user's message request."""
if self.mode == "tool":
raise ValueError("Tool mode is not supported here")

await self._handle_starter_message_if_needed(user_message)

runnable = cl.user_session.get("runnable")

## Build the input query
query = self._build_query(user_message)
## Process input query
assistant_message = await self._process_query(runnable, query)
await assistant_message.send()
# Save memory
memory.chat_memory.add_ai_message(msg.content)
memory.chat_memory.add_user_message(message.content)
self._update_memory(user_message, assistant_message)

async def _run(self, query: str) -> str:
"""Answer the question in the query"""
Expand Down

0 comments on commit cf2fa36

Please sign in to comment.