Multi modality (#84)
* feat: adds support for multi-modality in sage chat modes

* refactor: decouple the message handler

* fix: starter issue
thehapyone authored Oct 21, 2024
1 parent c3abc29 commit 0634c36
Showing 4 changed files with 184 additions and 58 deletions.
4 changes: 2 additions & 2 deletions .chainlit/config.toml
@@ -32,8 +32,8 @@ edit_message = true

# Authorize users to spontaneously upload files with messages
[features.spontaneous_file_upload]
enabled = false
accept = ["*/*"]
enabled = true
accept = ["image/jpeg", "image/png", "image/gif"]
max_files = 20
max_size_mb = 500

110 changes: 99 additions & 11 deletions sage/models/chat_prompt.py
@@ -1,3 +1,4 @@
import base64
from dataclasses import dataclass

from langchain.prompts import ChatPromptTemplate, PromptTemplate
@@ -42,7 +43,7 @@ class ChatPrompt:
Standalone question::
"""

qa_template_chat: str = """
qa_system_prompt: str = """
As an AI assistant named Sage, your mandate is to provide accurate and impartial answers to questions while engaging in normal conversation.
You must differentiate between questions that require answers and standard user chat conversations. In standard conversation, especially when discussing your own nature as an AI, footnotes or sources are not required, as the information is based on your programmed capabilities and functions. Your responses should adhere to a journalistic style, characterized by neutrality and reliance on factual, verifiable information.
@@ -59,6 +60,17 @@ class ChatPrompt:
- Avoid adding any sources in the footnotes when the response does not reference specific context.
- Citations must not be inserted anywhere in the answer, only listed in a 'Footnotes' section at the end of the response.
REMEMBER: No in-line citations and no citation repetition. State sources in the 'Footnotes' section. For standard conversation and questions about Sage's nature, no footnotes are required. Include footnotes only when they are directly relevant to the provided answer.
Footnotes:
[1] - Brief summary of the first source. (Less than 10 words)
[2] - Brief summary of the second source.
...continue for additional sources, only if relevant and necessary.
"""

qa_user_prompt: str = """
Question: {question}
<context>
{context}
</context>
@@ -67,21 +79,11 @@ class ChatPrompt:
<chat_history>
{chat_history}
</chat_history>
Question: {question}
REMEMBER: No in-line citations and no citation repetition. State sources in the 'Footnotes' section. For standard conversation and questions about Sage's nature, no footnotes are required. Include footnotes only when they are directly relevant to the provided answer.
Footnotes:
[1] - Brief summary of the first source. (Less than 10 words)
[2] - Brief summary of the second source.
...continue for additional sources, only if relevant and necessary.
"""

# The prompt template for the condense question chain
condense_prompt = PromptTemplate.from_template(condensed_template)

qa_prompt = ChatPromptTemplate.from_template(qa_template_chat)
"""The prompt template for the chat complete chain"""

def tool_description(self, source_repr: str) -> str:
@@ -142,3 +144,89 @@ def generate_welcome_message(
"To get started, simply select an option below; then begin typing your query or ask for help to see what I can do."
)
return message.strip()

def encode_image(self, image_path: str) -> str:
"""
Encodes an image file into a base64 string.
This function reads an image file from the provided file path,
encodes its binary data into a base64 format, and returns the
encoded string.
Args:
image_path (str): The path to the image file to be encoded.
Returns:
str: The base64 encoded string representation of the image.
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode()

def create_qa_prompt(
self, system_prompt, user_prompt, additional_user_prompts: list = None
):
"""
Creates a structured QA prompt template for the chat system.
This method returns a `ChatPromptTemplate` object by combining the
provided system-level prompt and user-level prompt messages. Additionally,
it can incorporate a list of extra user prompts, such as images or other media.
Args:
system_prompt (str): The primary prompt intended for the system's context.
user_prompt (str): The main prompt intended for the user's input.
additional_user_prompts (list, optional): A list of additional user prompts. Each entry
in the list should be a dictionary specifying
the type and content of the prompt.
Returns:
ChatPromptTemplate: A template object containing the fully structured prompt,
ready to be used in the chat system.
"""
user_messages = [{"type": "text", "text": user_prompt}]
if additional_user_prompts:
user_messages.extend(additional_user_prompts)
return ChatPromptTemplate.from_messages(
[
("system", system_prompt),
("user", user_messages),
]
)

def modality_prompt_router(self, x: dict) -> ChatPromptTemplate:
"""
Routes to the appropriate QA prompt template based on the presence of image data.
This function checks if the provided dictionary `x` contains image data and returns
the corresponding QA prompt template. If no image data is present, it returns the
standard QA prompt (`qa_prompt`). If image data is present, it processes each image
by encoding it to base64 and appending it to the additional user prompts, then
creates a new QA prompt (`qa_prompt_modality`) with the included image information.
Args:
x (dict): A dictionary that may contain image data with keys as follows:
- "image_data": A list of dictionaries, each containing:
- "mime": The MIME type of the image (e.g., 'image/jpeg').
- "path": The file path to the image to be encoded.
"""
images = x.get("image_data")

if not images:
# Standard Prompt Template
return self.create_qa_prompt(self.qa_system_prompt, self.qa_user_prompt)

images_content = [
{
"type": "image_url",
"image_url": {
"url": f"data:{image.mime};base64,{self.encode_image(image.path)}"
},
}
for image in images
]

# Prompt Template for Multi-Modality (Includes Image Data)
return self.create_qa_prompt(
system_prompt=self.qa_system_prompt,
user_prompt=self.qa_user_prompt,
additional_user_prompts=images_content,
)
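
For reference, a minimal sketch of how the new router behaves. `FakeImage` is a hypothetical stand-in for Chainlit's image elements (which expose `mime` and `path` attributes), and `chart.png` is a placeholder path assumed to exist on disk:

from dataclasses import dataclass

from sage.models.chat_prompt import ChatPrompt


@dataclass
class FakeImage:
    # Hypothetical stand-in for a Chainlit image element.
    mime: str
    path: str


prompt = ChatPrompt()

# No image data: the router returns the standard text-only QA prompt.
text_only = prompt.modality_prompt_router({"image_data": None})

# With image data: each file is base64-encoded and attached to the user
# message as an `image_url` content part.
multimodal = prompt.modality_prompt_router(
    {"image_data": [FakeImage(mime="image/png", path="chart.png")]}
)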
115 changes: 73 additions & 42 deletions sage/sources/qa.py
@@ -126,49 +126,57 @@ def _get_starter_source_label(message: str) -> Tuple[str, str]:
logger.warning("Error extracting label from starter %s", error)
return message, "none"

@cl.on_message
async def on_message(self, message: cl.Message):
"""Function to react user's message request"""
if self.mode == "tool":
raise ValueError("Tool mode is not supported here")
async def _handle_home_command(
self, message: cl.Message, chat_profile: str
) -> None:
"""Handle the '/home' command."""
intro_message = self.get_intro_message(chat_profile)
retriever = await self._mode_handlers.handle_chat_only_mode(
intro_message, message.id
)
self._runnable_handlers.setup_runnable(retriever=retriever)

# Handle starter message only once
if (
cl.user_session.get("starter_message", True)
and cl.user_session.get("chat_profile") == "Chat Only"
):
async def _handle_starter_message_if_needed(
self, message: cl.Message
) -> bool | None:
"""Handle the starter message if it's the first interaction."""
starter_message = cl.user_session.get("starter_message", True)
chat_profile = cl.user_session.get("chat_profile")

if starter_message and chat_profile == "Chat Only":
cl.user_session.set("starter_message", False)
chat_profile = cl.user_session.get("chat_profile")

if message.content == "/home":
retriever = await self._mode_handlers.handle_chat_only_mode(
self.get_intro_message(chat_profile), message.id
)
self._runnable_handlers.setup_runnable(retriever=retriever)
return
await self._handle_home_command(message, chat_profile)
return True

message.content, source_label = self._get_starter_source_label(
message.content
)
content, source_label = self._get_starter_source_label(message.content)
message.content = content
await message.update()

# Now we should set the retriever for the other messages
retriever = await self._mode_handlers.handle_chat_only_mode(
"", source_label=source_label
"",
source_label=source_label,
)
self._runnable_handlers.setup_runnable(retriever=retriever)

runnable: RunnableSequence = cl.user_session.get("runnable")
memory = get_memory(self.mode, cl.user_session)

msg = cl.Message(content="")
def _build_query(self, message: cl.Message) -> dict:
"""Build the query dictionary for the model inputs the message."""
images = [file for file in message.elements if file.mime.startswith("image")]
return {"question": message.content, "image_data": images if images else None}

query = {"question": message.content}
_sources = None
_answer = None
text_elements: list[cl.Text] = []
async def _stream_answer(self, answer: str, output_message: cl.Message) -> None:
"""Stream the answer tokens to the user."""
await output_message.stream_token(answer)

async def _process_query(
self, runnable: RunnableSequence, query: dict
) -> cl.Message:
"""Process the query through the runnable and stream the answer."""
run_name = getattr(runnable, "config", {}).get("run_name", "")
sources = None

final_message = cl.Message(content="")

async for chunk in runnable.astream(
query,
@@ -194,19 +202,20 @@ async def on_message(self, message: cl.Message):
],
),
):
_answer = chunk.get("answer")
if _answer:
await msg.stream_token(_answer)
if chunk_answer := chunk.get("answer"):
# Stream the answer token by token
await self._stream_answer(chunk_answer, final_message)

if chunk.get("sources"):
_sources = chunk.get("sources")
if chunk_sources := chunk.get("sources"):
sources = chunk_sources

# process sources
if _sources:
for source_doc in _sources:
# process sources if available
source_elements: list[cl.Text] = []
if sources:
for source_doc in sources:
source_name = f"[{source_doc['id']}]"
source_content = source_doc["content"] + "\n" + source_doc["source"]
text_elements.append(
source_elements.append(
cl.Text(
content=source_content,
name=source_name,
@@ -215,12 +224,34 @@ async def on_message(self, message: cl.Message):
)
)

msg.elements = text_elements
await msg.send()
final_message.elements = source_elements

return final_message

def _update_memory(self, user_message: cl.Message, ai_message: cl.Message) -> None:
"""Update the conversation memory with the latest messages."""
memory = get_memory(self.mode, cl.user_session)
memory.chat_memory.add_ai_message(ai_message.content)
memory.chat_memory.add_user_message(user_message.content)

@cl.on_message
async def on_message(self, user_message: cl.Message) -> None:
"""Handle and respond to a user's message request."""
if self.mode == "tool":
raise ValueError("Tool mode is not supported here")

if await self._handle_starter_message_if_needed(user_message):
return

runnable = cl.user_session.get("runnable")

## Build the input query
query = self._build_query(user_message)
## Process input query
assistant_message = await self._process_query(runnable, query)
await assistant_message.send()
# Save memory
memory.chat_memory.add_ai_message(msg.content)
memory.chat_memory.add_user_message(message.content)
self._update_memory(user_message, assistant_message)

async def _run(self, query: str) -> str:
"""Answer the question in the query"""
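
To make the new contract concrete, here is a minimal sketch of what `_build_query` hands to the runnable. The `Stub*` classes are hypothetical stand-ins for `cl.Message` and Chainlit file elements:

class StubElement:
    # Hypothetical stand-in for a Chainlit file element.
    def __init__(self, mime: str, path: str = ""):
        self.mime = mime
        self.path = path


class StubMessage:
    # Hypothetical stand-in for cl.Message.
    def __init__(self, content: str, elements: list):
        self.content = content
        self.elements = elements


def build_query(message) -> dict:
    # Same shape as `_build_query` above: keep only image attachments.
    images = [f for f in message.elements if f.mime.startswith("image")]
    return {"question": message.content, "image_data": images or None}


query = build_query(
    StubMessage(
        "What does this diagram show?",
        [StubElement("image/png", "diagram.png"), StubElement("application/pdf")],
    )
)
assert query["question"] == "What does this diagram show?"
assert len(query["image_data"]) == 1  # the PDF attachment is filtered out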
13 changes: 10 additions & 3 deletions sage/sources/runnable.py
@@ -49,11 +49,15 @@ def _create_chat_runnable(
self, _inputs, _retrieved_docs, _context
) -> RunnableSequence:
"""Implementation for creating chat runnable"""
qa_prompt = ChatPrompt().qa_prompt

# construct the question and answer model
qa_answer = RunnableMap(
answer=_context | qa_prompt | self.base_model | StrOutputParser(),
answer=_context
| RunnableLambda(ChatPrompt().modality_prompt_router).with_config(
run_name="Modality-Router"
)
| self.base_model
| StrOutputParser(),
sources=lambda x: format_sources(x["docs"]),
).with_config(run_name="Sage Assistant")

@@ -100,20 +104,23 @@ def standalone_chain_router(x: dict):
"question": lambda x: x["question"],
"chat_history": chat_history_loader,
}
| RunnableLambda(standalone_chain_router).with_config(run_name="Condenser")
| RunnableLambda(standalone_chain_router).with_config(run_name="Condenser"),
image_data=itemgetter("image_data"),
)

# retrieve the documents
_retrieved_docs = RunnableMap(
docs=itemgetter("standalone") | retriever,
question=itemgetter("standalone"),
image_data=itemgetter("image_data"),
).with_config(run_name="Source Retriever")

# construct the context inputs
_context = RunnableMap(
context=lambda x: format_docs(x["docs"]),
chat_history=chat_history_loader,
question=itemgetter("question"),
image_data=itemgetter("image_data"),
)
_runnable = self._create_chat_runnable(_inputs, _retrieved_docs, _context)

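
The repeated `image_data=itemgetter("image_data")` entries are what keep the attachments alive across the chain: every `RunnableMap` emits a fresh dict, so any key it does not explicitly forward is dropped before the modality router runs. A standalone sketch of that behavior, using only langchain-core:

from operator import itemgetter

from langchain_core.runnables import RunnableMap

# Each RunnableMap builds a brand-new dict from its declared steps, so a
# key that is not re-mapped here never reaches downstream runnables.
step = RunnableMap(
    question=itemgetter("question"),
    image_data=itemgetter("image_data"),  # forwarded untouched
)

out = step.invoke({"question": "hi", "image_data": ["img"], "extra": 1})
assert out == {"question": "hi", "image_data": ["img"]}  # "extra" was dropped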
