diff --git a/app.py b/app.py index a62fd62..cc55b69 100644 --- a/app.py +++ b/app.py @@ -2,7 +2,8 @@ import uuid import streamlit as st import datasets -from langchain_huggingface import HuggingFaceEndpointEmbeddings, HuggingFaceEndpoint +from langchain_huggingface import HuggingFaceEndpointEmbeddings, ChatHuggingFace +from langchain_community.llms import HuggingFaceEndpoint from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough from langchain.agents import create_react_agent, AgentExecutor @@ -49,8 +50,8 @@ def setup_chroma_embedding_function(): return chroma_embedding_function # Set up HuggingFaceEndpoint model -def setup_huggingface_endpoint(): - model = HuggingFaceEndpoint( +def setup_huggingface_endpoint(model_id): + llm = HuggingFaceEndpoint( endpoint_url="http://{host}:{port}".format( host=os.getenv("TGI_HOST", "localhost"), port=os.getenv("TGI_PORT", "8080") ), @@ -61,6 +62,10 @@ def setup_huggingface_endpoint(): "{your_token}".format(your_token=os.getenv("STOP_TOKEN", "<|end_of_text|>")), ], ) + + model = ChatHuggingFace(llm=llm, + model_id=model_id) + return model def setup_portkey_integrated_model(): @@ -220,6 +225,7 @@ def setup_tools(_model, _client, _chroma_embedding_function, _embedder): # chroma_embedding_function=_chroma_embedding_function, # embedder=_embedder, #) + if os.getenv("USE_RERANKER", "False") == "True": retriever = create_reranker_retriever( name="slack_conversations_retriever", @@ -245,7 +251,7 @@ def setup_tools(_model, _client, _chroma_embedding_function, _embedder): @st.cache_resource def setup_agent(_model, _prompt, _client, _chroma_embedding_function, _embedder): tools = setup_tools(_model, _client, _chroma_embedding_function, _embedder) - agent = create_react_agent(llm=_model, prompt=_prompt, tools=tools,) + agent = create_react_agent(llm=_model, prompt=_prompt, tools=tools, ) agent_executor = AgentExecutor( agent=agent, verbose=True, tools=tools, handle_parsing_errors=True ) @@ -258,7 +264,7 @@ def main(): if os.getenv("ENABLE_PORTKEY", "False") == "True": model = setup_portkey_integrated_model() else: - model = setup_huggingface_endpoint() + model = setup_huggingface_endpoint(model_id="qwen/Qwen2-7B-Instruct") embedder = setup_huggingface_embeddings() agent_executor = setup_agent(