From 5d0228d4a136990f1fc1d7869904b692c3ce17c0 Mon Sep 17 00:00:00 2001
From: hsaeed3
Date: Mon, 30 Sep 2024 15:19:27 -0700
Subject: [PATCH] basemodel hotfix && ruff

---
 pyproject.toml                                 |   2 +-
 zyx/__init__.py                                |  44 +-
 zyx/client.py                                  | 449 +++++++++---------
 zyx/lib/router/agents/__init__.py              |  10 +-
 zyx/lib/router/agents/__init__.pyi             |  11 +-
 zyx/lib/router/data/__init__.pyi               |   1 -
 zyx/lib/router/ext/__init__.py                 |  19 +-
 zyx/lib/router/ext/__init__.pyi                |  16 +-
 zyx/lib/router/llm/__init__.py                 |   2 +-
 zyx/lib/router/llm/__init__.pyi                |   4 +-
 zyx/lib/types/base_model.py                    |  73 ++-
 zyx/lib/types/document.py                      |  41 +-
 zyx/lib/utils/logger.py                        |  12 +-
 .../completions/agents/conversation.py         | 122 +++--
 zyx/resources/completions/agents/judge.py      | 148 ++++--
 zyx/resources/completions/agents/plan.py       |  37 +-
 zyx/resources/completions/agents/query.py      |  33 +-
 zyx/resources/completions/agents/scrape.py     |  33 +-
 zyx/resources/completions/agents/solve.py      |  51 +-
 zyx/resources/completions/base/classify.py     |  82 ++--
 zyx/resources/completions/base/code.py         |  34 +-
 zyx/resources/completions/base/extract.py      |  55 +--
 zyx/resources/completions/base/function.py     |  76 +--
 zyx/resources/completions/base/generate.py     |  76 +--
 .../completions/base/system_prompt.py          |  65 ++-
 zyx/resources/data/chunk.py                    |   2 +-
 zyx/resources/data/reader.py                   |   2 +-
 zyx/resources/ext/app.py                       |  19 +-
 zyx/resources/stores/memory.py                 |  70 ++-
 29 files changed, 857 insertions(+), 732 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index a2d96bb..a2b7b55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zyx"
-version = "1.0.2"
+version = "1.0.3"
 description = "A hyper-fast, fun, quality-of-life focused & genuinely useful LLM toolkit. Inspired by Marvin-AI. Built on LiteLLM, Instructor & Qdrant."
authors = ["Hammad Saeed "] readme = "readme.md" diff --git a/zyx/__init__.py b/zyx/__init__.py index 50afdfd..8b76e58 100644 --- a/zyx/__init__.py +++ b/zyx/__init__.py @@ -1,30 +1,24 @@ __all__ = [ # utils "logger", - # modules "llm", "agents", "data", "tools", - # Core (Types) "BaseModel", "Field", "Document", - # data - core "Memory", - # data - tools "embeddings", "chunk", "read", - # llm - core "Client", "completion", - # llm - base functions "classify", "code", @@ -32,7 +26,6 @@ "function", "generate", "system_prompt", - # llm - agentic reasoning "Character", "conversation", @@ -41,12 +34,10 @@ "query", "scrape", "solve", - # ext - multimodal "image", "audio", "transcribe", - # ext - app "app", ] @@ -56,33 +47,34 @@ from .lib.utils.logger import logger # modules -from .lib.router import ( - llm, agents, data -) +from .lib.router import llm, agents, data from .resources import tools # data -from .lib.router.data import ( - Memory, - Document, - embeddings, - chunk, read -) +from .lib.router.data import Memory, Document, embeddings, chunk, read # llm - base & core from .lib.router.llm import ( - Client, completion, - classify, code, extract, function, generate, system_prompt + Client, + completion, + classify, + code, + extract, + function, + generate, + system_prompt, ) # llm - agents from .lib.router.agents import ( - Character, conversation, - judge, plan, query, scrape, solve + Character, + conversation, + judge, + plan, + query, + scrape, + solve, ) # ext -from .lib.router.ext import ( - BaseModel, Field, - app, image, audio, transcribe -) \ No newline at end of file +from .lib.router.ext import BaseModel, Field, app, image, audio, transcribe diff --git a/zyx/client.py b/zyx/client.py index acaa282..76dc2ef 100644 --- a/zyx/client.py +++ b/zyx/client.py @@ -21,6 +21,7 @@ ## -- Instructor Configuration -- ## + ## -- Instructor Mode -- ## ## This was directly ported from instructor ## https://github.com/jxnl/instructor/ @@ -142,7 +143,6 @@ class CompletionArgs(BaseModel): class Client: - """ Base class for all LLM completions in the zyx library. Runs using either the OpenAI or LiteLLM client libraries. @@ -197,9 +197,7 @@ def recommend_client_by_model( return client, model, base_url, api_key @staticmethod - def format_to_openai_tools( - tools: List[ToolType] - ) -> List[Tool]: + def format_to_openai_tools(tools: List[ToolType]) -> List[Tool]: """Converts the tools to a list of dictionaries. Args: @@ -238,7 +236,6 @@ def get_tool_dict(tools: List[Tool]) -> List[Dict[str, Any]]: tool_dict.append(tool.openai_tool) return tool_dict - @staticmethod def format_messages( @@ -262,7 +259,9 @@ def format_messages( print(f"Converting string to message format.") return [{"role": type, "content": messages}] - elif isinstance(messages, list) and all(isinstance(m, dict) for m in messages): + elif isinstance(messages, list) and all( + isinstance(m, dict) for m in messages + ): if verbose: print(f"Messages are in the correct format.") @@ -273,7 +272,6 @@ def format_messages( print(f"Error formatting messages: {e}") return [] - @staticmethod def does_system_prompt_exist(messages: list[dict]) -> bool: """Simple boolean check to see if a system prompt exists in the messages. 
@@ -287,7 +285,6 @@ def does_system_prompt_exist(messages: list[dict]) -> bool: return any(message.get("role") == "system" for message in messages) - @staticmethod def swap_system_prompt( system_prompt: dict = None, messages: Union[str, list[dict[str, str]]] = None @@ -321,12 +318,14 @@ def swap_system_prompt( break # Remove any duplicate system messages - while len([message for message in messages if message.get("role") == "system"]) > 1: + while ( + len([message for message in messages if message.get("role") == "system"]) + > 1 + ): messages.pop() return messages - @staticmethod def repair_messages( messages: list[dict], verbose: Optional[bool] = False @@ -371,7 +370,6 @@ def repair_messages( return messages - @staticmethod def add_messages( inputs: Union[str, list[dict], dict] = None, @@ -394,7 +392,9 @@ def add_messages( """ if isinstance(inputs, str): - formatted_message = Client.format_messages(messages=inputs, verbose=verbose, type=type) + formatted_message = Client.format_messages( + messages=inputs, verbose=verbose, type=type + ) messages.extend(formatted_message) @@ -410,21 +410,19 @@ def add_messages( print(f"Skipping invalid message format: {item}") return Client.repair_messages(messages, verbose) - def __init__( - self, - - api_key : Optional[str] = None, - base_url : Optional[str] = None, - organization : Optional[str] = None, - provider : Optional[Literal["openai", "litellm"]] = None, - verbose : bool = False + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + organization: Optional[str] = None, + provider: Optional[Literal["openai", "litellm"]] = None, + verbose: bool = False, ): """Initializes the completion client with the specified parameters. - + Example: - + ```python client = Client( api_key = "sk-...", @@ -445,14 +443,14 @@ def __init__( Returns: None """ - + self.clients = ClientProviders() self.config = ClientConfig( api_key=api_key, base_url=base_url, organization=organization, - verbose=verbose + verbose=verbose, ) self.provider = provider @@ -460,19 +458,18 @@ def __init__( if self.provider: self.clients.client = self.__init_client__() - def __init_client__(self): """ Initializes the specified client library. """ - + if self.provider == "openai": from openai import OpenAI client = OpenAI( - api_key = self.config.api_key, - base_url = self.config.base_url, - organization = self.config.organization + api_key=self.config.api_key, + base_url=self.config.base_url, + organization=self.config.organization, ) elif self.provider == "litellm": @@ -482,16 +479,15 @@ def __init_client__(self): litellm.drop_params = True client = LiteLLM( - api_key = self.config.api_key, - base_url = self.config.base_url, - organization = self.config.organization + api_key=self.config.api_key, + base_url=self.config.base_url, + organization=self.config.organization, ) if self.config.verbose: logger.info(f"Initialized {self.provider} client") return client - def __patch_client__(self): """ @@ -506,7 +502,7 @@ def __patch_client__(self): from instructor import from_openai patched_client = from_openai(self.clients.client) - + else: from instructor import patch @@ -517,10 +513,7 @@ def __patch_client__(self): return patched_client - - def chat_completion( - self, args : CompletionArgs - ): + def chat_completion(self, args: CompletionArgs): """ Runs a standard chat completion. @@ -530,17 +523,19 @@ def chat_completion( Returns: CompletionResponse: The response to the completion. 
""" - + exclude_params = {"response_model"} try: if args.tools is None: exclude_params.update({"tools", "parallel_tool_calls", "tool_choice"}) - + # O1 Specific Handler # Will be removed once OpenAI supports all O1 Parameters if args.model.startswith("o1-"): - logger.warning("OpenAI O1- model detected. Removing all non-supported parameters.") + logger.warning( + "OpenAI O1- model detected. Removing all non-supported parameters." + ) exclude_params.update( { "max_tokens", @@ -559,15 +554,15 @@ def chat_completion( if self.config.verbose: logger.info(f"Streaming completion... with {args.model} model") - stream = self.clients.client.chat.completions.create( - **args.model_dump(exclude = exclude_params) + stream = self.clients.client.chat.completions.create( + **args.model_dump(exclude=exclude_params) ) return ( chunk.choices[0].delta.content for chunk in stream if chunk.choices[0].delta.content ) - + else: if self.config.verbose: logger.info(f"Generating completion... with {args.model} model") @@ -575,75 +570,73 @@ def chat_completion( exclude_params.add("stream") return self.clients.client.chat.completions.create( - **args.model_dump(exclude = exclude_params) + **args.model_dump(exclude=exclude_params) ) - + except Exception as e: logger.error(f"Error in chat_completion: {e}") raise - - def instructor_completion( - self, args : CompletionArgs - ): + def instructor_completion(self, args: CompletionArgs): """Runs an Instructor completion - + Args: args: CompletionArgs: The arguments to the completion. Returns: CompletionResponse: The response to the completion. """ - + try: if not self.clients.instructor: self.clients.instructor = self.__patch_client__() - if args.tools is None: - exclude_params = ({"tools", "parallel_tool_calls", "tool_choice"}) + exclude_params = {"tools", "parallel_tool_calls", "tool_choice"} if args.model.startswith("o1-"): - logger.warning("OpenAI O1- model detected. Removing all non-supported parameters.") - exclude_params = ( - { - "max_tokens", - "temperature", - "top_p", - "frequency_penalty", - "presence_penalty", - "tools", - "parallel_tool_calls", - "tool_choice", - "stop", - } + logger.warning( + "OpenAI O1- model detected. Removing all non-supported parameters." ) + exclude_params = { + "max_tokens", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + "tools", + "parallel_tool_calls", + "tool_choice", + "stop", + } if args.stream: if self.config.verbose: - logger.info(f"Streaming Instructor completion... with {args.model} model") + logger.info( + f"Streaming Instructor completion... with {args.model} model" + ) exclude_params.add("stream") return self.clients.instructor.chat.completions.create_partial( - **args.model_dump(exclude = exclude_params) + **args.model_dump(exclude=exclude_params) ) - - else: + else: if self.config.verbose: - logger.info(f"Generating Instructor completion... with {args.model} model") + logger.info( + f"Generating Instructor completion... 
with {args.model} model" + ) exclude_params.add("stream") return self.clients.instructor.chat.completions.create( - **args.model_dump(exclude = exclude_params) + **args.model_dump(exclude=exclude_params) ) except Exception as e: logger.error(f"Error in instructor_completion: {e}") raise - def execute_tool_call( self, @@ -706,25 +699,24 @@ def execute_tool_call( return response - def run_completion( - self, - messages: Union[str, list[dict]] = None, - model: str = "gpt-4o", - response_model: Optional[Type[BaseModel]] = None, - mode: Optional[InstructorMode] = "tool_call", - max_retries: Optional[int] = 3, - run_tools: Optional[bool] = True, - tools: Optional[List[ToolType]] = None, - parallel_tool_calls: Optional[bool] = False, - tool_choice: Optional[Literal["none", "auto", "required"]] = "auto", - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - stop: Optional[List[str]] = None, - stream: Optional[bool] = False, + self, + messages: Union[str, list[dict]] = None, + model: str = "gpt-4o", + response_model: Optional[Type[BaseModel]] = None, + mode: Optional[InstructorMode] = "tool_call", + max_retries: Optional[int] = 3, + run_tools: Optional[bool] = True, + tools: Optional[List[ToolType]] = None, + parallel_tool_calls: Optional[bool] = False, + tool_choice: Optional[Literal["none", "auto", "required"]] = "auto", + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + stop: Optional[List[str]] = None, + stream: Optional[bool] = False, ): """ Runs a completion with the specified arguments. @@ -735,7 +727,7 @@ def run_completion( messages = "Hello!", model = "gpt-4o-mini ) - ``` + ``` Args: messages: Union[str, list[dict]]: The messages to complete. @@ -758,34 +750,32 @@ def run_completion( Returns: CompletionResponse: The completion response. 
""" - - + formatted_tools = None if tools: formatted_tools = self.format_to_openai_tools(tools) args = CompletionArgs( - messages = self.format_messages(messages), - model = model, - response_model = response_model, - tools = self.get_tool_dict(formatted_tools) if formatted_tools else None, - parallel_tool_calls = parallel_tool_calls, - tool_choice = tool_choice, - max_tokens = max_tokens, - temperature = temperature, - top_p = top_p, - frequency_penalty = frequency_penalty, - presence_penalty = presence_penalty, - stop = stop, - stream = stream + messages=self.format_messages(messages), + model=model, + response_model=response_model, + tools=self.get_tool_dict(formatted_tools) if formatted_tools else None, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stop=stop, + stream=stream, ) if not response_model: if not run_tools or not formatted_tools: return self.chat_completion(args) - - else: + else: args.stream = False base_response = self.chat_completion(args) @@ -800,9 +790,8 @@ def run_completion( return self.chat_completion(args) else: return base_response - - else: + else: if formatted_tools: original_args = args @@ -818,10 +807,9 @@ def run_completion( return self.instructor_completion(args) else: return self.instructor_completion(original_args) - + else: return self.instructor_completion(args) - def completion( self, @@ -845,7 +833,7 @@ def completion( stop: Optional[List[str]] = None, stream: Optional[bool] = False, provider: Optional[Literal["openai", "litellm"]] = None, - verbose: bool = False + verbose: bool = False, ): """ Runs a completion with the specified arguments. @@ -884,7 +872,12 @@ def completion( Returns: CompletionResponse: The completion response. """ - recommended_provider, recommended_model, recommended_base_url, recommended_api_key = self.recommend_client_by_model(model, base_url, api_key) + ( + recommended_provider, + recommended_model, + recommended_base_url, + recommended_api_key, + ) = self.recommend_client_by_model(model, base_url, api_key) # Reinitialize client only if the recommended provider is different if recommended_provider != self.provider: @@ -893,7 +886,7 @@ def completion( base_url=recommended_base_url or base_url or self.config.base_url, organization=organization or self.config.organization, provider=recommended_provider, - verbose=verbose or self.config.verbose + verbose=verbose or self.config.verbose, ) # Update model if it was changed by recommend_client_by_model @@ -906,7 +899,9 @@ def completion( mode = get_mode(mode) if model.startswith("o1-"): - logger.warning("OpenAI O1- model detected. Using JSON_O1 Instructor Mode.") + logger.warning( + "OpenAI O1- model detected. Using JSON_O1 Instructor Mode." 
+ ) mode = Mode.JSON_O1 self.clients.instructor.mode = mode @@ -927,33 +922,32 @@ def completion( frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, stop=stop, - stream=stream + stream=stream, ) - @staticmethod def _completion( - messages: Union[str, list[dict]] = None, - model: str = "gpt-4o", - api_key : Optional[str] = None, - base_url : Optional[str] = None, - organization : Optional[str] = None, - response_model: Optional[Type[BaseModel]] = None, - mode: Optional[InstructorMode] = "tool_call", - max_retries: Optional[int] = 3, - run_tools: Optional[bool] = True, - tools: Optional[List[ToolType]] = None, - parallel_tool_calls: Optional[bool] = False, - tool_choice: Optional[Literal["none", "auto", "required"]] = "auto", - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - stop: Optional[List[str]] = None, - stream: Optional[bool] = False, - provider : Optional[Literal["openai", "litellm"]] = None, - verbose : bool = False + messages: Union[str, list[dict]] = None, + model: str = "gpt-4o", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + organization: Optional[str] = None, + response_model: Optional[Type[BaseModel]] = None, + mode: Optional[InstructorMode] = "tool_call", + max_retries: Optional[int] = 3, + run_tools: Optional[bool] = True, + tools: Optional[List[ToolType]] = None, + parallel_tool_calls: Optional[bool] = False, + tool_choice: Optional[Literal["none", "auto", "required"]] = "auto", + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + stop: Optional[List[str]] = None, + stream: Optional[bool] = False, + provider: Optional[Literal["openai", "litellm"]] = None, + verbose: bool = False, ): """ Runs a completion with the specified arguments. @@ -964,8 +958,8 @@ def _completion( messages = "Hello!", model = "gpt-4o-mini ) - ``` - + ``` + Args: messages: Union[str, list[dict]]: The messages to complete. model: str: The model to use. @@ -992,28 +986,27 @@ def _completion( Returns: CompletionResponse: The completion response. """ - + if provider: client = Client( - api_key = api_key, - base_url = base_url, - organization = organization, - provider = provider, - verbose = verbose + api_key=api_key, + base_url=base_url, + organization=organization, + provider=provider, + verbose=verbose, ) else: provider, model, base_url, api_key = Client.recommend_client_by_model(model) client = Client( - api_key = api_key, - base_url = base_url, - organization = organization, - provider = provider, - verbose = verbose + api_key=api_key, + base_url=base_url, + organization=organization, + provider=provider, + verbose=verbose, ) - if response_model: if not client.clients.instructor: client.clients.instructor = client.__patch_client__() @@ -1021,58 +1014,58 @@ def _completion( mode = get_mode(mode) if model.startswith("o1-"): - logger.warning("OpenAI O1- model detected. Using JSON_O1 Instructor Mode.") + logger.warning( + "OpenAI O1- model detected. Using JSON_O1 Instructor Mode." 
+ ) mode = Mode.JSON_O1 client.clients.instructor.mode = mode - return client.run_completion( - messages = messages, - model = model, - response_model = response_model, - mode = mode, - max_retries = max_retries, - run_tools = run_tools, - tools = tools, - parallel_tool_calls = parallel_tool_calls, - tool_choice = tool_choice, - max_tokens = max_tokens, - temperature = temperature, - top_p = top_p, - frequency_penalty = frequency_penalty, - presence_penalty = presence_penalty, - stop = stop, - stream = stream + messages=messages, + model=model, + response_model=response_model, + mode=mode, + max_retries=max_retries, + run_tools=run_tools, + tools=tools, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stop=stop, + stream=stream, ) - + def completion( - messages: Union[str, list[dict]] = None, - model: str = "gpt-4o", - api_key : Optional[str] = None, - base_url : Optional[str] = None, - organization : Optional[str] = None, - response_model: Optional[Type[BaseModel]] = None, - mode: Optional[InstructorMode] = "tool_call", - max_retries: Optional[int] = 3, - run_tools: Optional[bool] = True, - tools: Optional[List[ToolType]] = None, - parallel_tool_calls: Optional[bool] = False, - tool_choice: Optional[Literal["none", "auto", "required"]] = "auto", - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - stop: Optional[List[str]] = None, - stream: Optional[bool] = False, - client : Optional[Literal["openai", "litellm"]] = None, - verbose : bool = False + messages: Union[str, list[dict]] = None, + model: str = "gpt-4o", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + organization: Optional[str] = None, + response_model: Optional[Type[BaseModel]] = None, + mode: Optional[InstructorMode] = "tool_call", + max_retries: Optional[int] = 3, + run_tools: Optional[bool] = True, + tools: Optional[List[ToolType]] = None, + parallel_tool_calls: Optional[bool] = False, + tool_choice: Optional[Literal["none", "auto", "required"]] = "auto", + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + stop: Optional[List[str]] = None, + stream: Optional[bool] = False, + client: Optional[Literal["openai", "litellm"]] = None, + verbose: bool = False, ) -> CompletionResponse: - """Runs an LLM completion, with tools, streaming or Pydantic structured outputs. 
- + Example: ```python @@ -1115,50 +1108,46 @@ def completion( provider = client return Client._completion( - messages = messages, - model = model, - api_key = api_key, - base_url = base_url, - organization = organization, - response_model = response_model, - mode = mode, - max_retries = max_retries, - run_tools = run_tools, - tools = tools, - parallel_tool_calls = parallel_tool_calls, - tool_choice = tool_choice, - max_tokens = max_tokens, - temperature = temperature, - top_p = top_p, - frequency_penalty = frequency_penalty, - presence_penalty = presence_penalty, - stop = stop, - stream = stream, - provider = provider, - verbose = verbose + messages=messages, + model=model, + api_key=api_key, + base_url=base_url, + organization=organization, + response_model=response_model, + mode=mode, + max_retries=max_retries, + run_tools=run_tools, + tools=tools, + parallel_tool_calls=parallel_tool_calls, + tool_choice=tool_choice, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stop=stop, + stream=stream, + provider=provider, + verbose=verbose, ) if __name__ == "__main__": class PersonModel(BaseModel): - secret_identity : str - name : str - age : int + secret_identity: str + name: str + age: int - def get_secret_identity(name : str): + def get_secret_identity(name: str): return "Batman" - print( - completion( - "Who is SpiderMan", verbose = True, response_model = PersonModel - ) - ) + print(completion("Who is SpiderMan", verbose=True, response_model=PersonModel)) print( completion( - messages = "Who is SpiderMan", - tools = [get_secret_identity], - verbose = True, + messages="Who is SpiderMan", + tools=[get_secret_identity], + verbose=True, ) ) diff --git a/zyx/lib/router/agents/__init__.py b/zyx/lib/router/agents/__init__.py index 926f67f..170cc60 100644 --- a/zyx/lib/router/agents/__init__.py +++ b/zyx/lib/router/agents/__init__.py @@ -1,12 +1,4 @@ -__all__ = [ - "Character", - "conversation", - "judge", - "plan", - "query", - "scrape", - "solve" -] +__all__ = ["Character", "conversation", "judge", "plan", "query", "scrape", "solve"] from .._router import router diff --git a/zyx/lib/router/agents/__init__.pyi b/zyx/lib/router/agents/__init__.pyi index 8f26430..ecd5359 100644 --- a/zyx/lib/router/agents/__init__.pyi +++ b/zyx/lib/router/agents/__init__.pyi @@ -1,13 +1,4 @@ -__all__ = [ - "Character", - "conversation", - "judge", - "plan", - "query", - "scrape", - "solve" -] - +__all__ = ["Character", "conversation", "judge", "plan", "query", "scrape", "solve"] from ....resources.completions.agents.conversation import conversation as conversation from ....resources.completions.agents.conversation import Character as Character diff --git a/zyx/lib/router/data/__init__.pyi b/zyx/lib/router/data/__init__.pyi index 7fb98ba..26fe03e 100644 --- a/zyx/lib/router/data/__init__.pyi +++ b/zyx/lib/router/data/__init__.pyi @@ -7,7 +7,6 @@ __all__ = [ "read", ] - from ....resources.stores.memory import Memory as Memory from ...types.document import Document as Document from litellm.main import embedding as embeddings diff --git a/zyx/lib/router/ext/__init__.py b/zyx/lib/router/ext/__init__.py index e0a944c..18c2c30 100644 --- a/zyx/lib/router/ext/__init__.py +++ b/zyx/lib/router/ext/__init__.py @@ -1,19 +1,10 @@ -__all__ = [ - "BaseModel", - "Field", - "app", - "image", - "audio", - "transcribe" -] +__all__ = ["BaseModel", "Field", "app", "image", "audio", "transcribe"] from .._router import router from 
...types.base_model import BaseModel -from pydantic import ( - Field -) +from pydantic import Field class app(router): @@ -42,9 +33,3 @@ class transcribe(router): transcribe.init("zyx.resources.ext.multimodal", "transcribe") - - - - - - diff --git a/zyx/lib/router/ext/__init__.pyi b/zyx/lib/router/ext/__init__.pyi index 3b1ea9c..9ee7899 100644 --- a/zyx/lib/router/ext/__init__.pyi +++ b/zyx/lib/router/ext/__init__.pyi @@ -1,19 +1,9 @@ -__all__ = [ - "BaseModel", - "Field", - "app", - "image", - "audio", - "transcribe" -] - +__all__ = ["BaseModel", "Field", "app", "image", "audio", "transcribe"] from ...types.base_model import BaseModel as BaseModel -from pydantic import ( - Field as Field -) +from pydantic import Field as Field from ....resources.ext.app import terminal as app from litellm.main import image_generation as image from ....resources.ext.multimodal import audio as audio -from ....resources.ext.multimodal import transcribe as transcribe \ No newline at end of file +from ....resources.ext.multimodal import transcribe as transcribe diff --git a/zyx/lib/router/llm/__init__.py b/zyx/lib/router/llm/__init__.py index 012da4d..1c5d952 100644 --- a/zyx/lib/router/llm/__init__.py +++ b/zyx/lib/router/llm/__init__.py @@ -6,7 +6,7 @@ "extract", "function", "generate", - "system_prompt" + "system_prompt", ] diff --git a/zyx/lib/router/llm/__init__.pyi b/zyx/lib/router/llm/__init__.pyi index 3e5e524..9576b90 100644 --- a/zyx/lib/router/llm/__init__.pyi +++ b/zyx/lib/router/llm/__init__.pyi @@ -6,16 +6,14 @@ __all__ = [ "extract", "function", "generate", - "system_prompt" + "system_prompt", ] - from ....client import ( Client as Client, completion as completion, ) - from ....resources.completions.base.classify import classify as classify from ....resources.completions.base.code import code as code from ....resources.completions.base.extract import extract as extract diff --git a/zyx/lib/types/base_model.py b/zyx/lib/types/base_model.py index 31af1b9..4383521 100644 --- a/zyx/lib/types/base_model.py +++ b/zyx/lib/types/base_model.py @@ -1,13 +1,16 @@ from pydantic import create_model, BaseModel as PydanticBaseModel from typing import Optional, Literal, List, Type, TypeVar, Union, overload from ...client import InstructorMode, Client +from ..utils.logger import get_logger -T = TypeVar('T', bound='BaseModel') +logger = get_logger("base_model") -class BaseModel(PydanticBaseModel): +T = TypeVar("T", bound="BaseModel") + +class BaseModel(PydanticBaseModel): @overload @classmethod def generate( @@ -25,8 +28,7 @@ def generate( temperature: float = 0, mode: InstructorMode = "markdown_json_mode", verbose: bool = False, - ) -> List[T]: - ... + ) -> List[T]: ... @overload def generate( @@ -44,8 +46,7 @@ def generate( temperature: float = 0, mode: InstructorMode = "markdown_json_mode", verbose: bool = False, - ) -> List[T]: - ... + ) -> List[T]: ... @classmethod def generate( @@ -119,14 +120,26 @@ class User(BaseModel): if isinstance(cls_or_self, BaseModel): system_message += f"\n\nUse the following instance as a reference or starting point:\n{cls_or_self.model_dump_json()}" - user_message = instructions if instructions else f"Generate {n} instance(s) of the given model." + system_message += ( + f"\nALWAYS COMPLY WITH USER INSTRUCTIONS FOR CONTENT TOPICS & GUIDELINES." + ) + + user_message = ( + instructions + if instructions + else f"Generate {n} instance(s) of the given model." 
+ ) + + if verbose: + logger.info(f"Template: {system_message}") + logger.info(f"Instructions: {user_message}") completion_client = Client( api_key=api_key, base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) if process == "batch": @@ -140,7 +153,9 @@ class User(BaseModel): max_tokens=max_tokens, max_retries=max_retries, temperature=temperature, - mode="markdown_json_mode" if model.startswith(("ollama/", "ollama_chat/")) else mode, + mode="markdown_json_mode" + if model.startswith(("ollama/", "ollama_chat/")) + else mode, response_model=ResponseModel, ) return [response] if n == 1 else response.items @@ -157,18 +172,35 @@ class User(BaseModel): Field constraints: {field.json_schema_extra} Ensure that the generated value complies with the field's type and constraints. + + \nALWAYS COMPLY WITH USER INSTRUCTIONS FOR CONTENT TOPICS & GUIDELINES. """ - field_user_message = f"Generate a value for the '{field_name}' field." + field_user_message = ( + f"Generate a value for the '{field_name}' field." + ) + if instance: field_user_message += f"\nCurrent partial instance: {instance}" # Add information about previous generations if i > 0: - field_user_message += f"\n\nPrevious generations for this field:" - for j, prev_instance in enumerate(results[-min(3, i):], 1): - field_user_message += f"\n{j}. {getattr(prev_instance, field_name)}" + field_user_message += ( + f"\n\nPrevious generations for this field:" + ) + for j, prev_instance in enumerate(results[-min(3, i) :], 1): + field_user_message += ( + f"\n{j}. {getattr(prev_instance, field_name)}" + ) field_user_message += "\n\nPlease generate a different value from these previous ones." + field_user_message += f"""\n\n + USER INSTRUCTIONS DEFINED BELOW FOR CONTENT & GUIDELINES + + + {instructions if instructions else "No additional instructions provided."} + \n\n + """ + field_response = completion_client.completion( messages=[ {"role": "system", "content": field_system_message}, @@ -178,8 +210,12 @@ class User(BaseModel): max_tokens=max_tokens, max_retries=max_retries, temperature=temperature, - mode="markdown_json_mode" if model.startswith(("ollama/", "ollama_chat/")) else mode, - response_model=create_model("FieldResponse", value=(field.annotation, ...)), + mode="markdown_json_mode" + if model.startswith(("ollama/", "ollama_chat/")) + else mode, + response_model=create_model( + "FieldResponse", value=(field.annotation, ...) + ), ) instance[field_name] = field_response.value @@ -191,9 +227,10 @@ class User(BaseModel): if __name__ == "__main__": class TestData(BaseModel): - compounds : List[str] - + compounds: List[str] - compounds = TestData.generate("make me some data", n=5, process = "sequential", verbose = True) + compounds = TestData.generate( + "make me some data", n=5, process="sequential", verbose=True + ) print(compounds) diff --git a/zyx/lib/types/document.py b/zyx/lib/types/document.py index 226a49f..382d1c4 100644 --- a/zyx/lib/types/document.py +++ b/zyx/lib/types/document.py @@ -7,7 +7,7 @@ ) -T = TypeVar('T', bound=BaseModel) +T = TypeVar("T", bound=BaseModel) class Document(BaseModel): @@ -19,6 +19,7 @@ class Document(BaseModel): metadata (Dict[str, Any]): The metadata of the document. messages (Optional[List[Dict[str, Any]]]): The messages of the document. 
""" + content: Any metadata: Dict[str, Any] messages: Optional[List[Dict[str, Any]]] = [] @@ -49,7 +50,6 @@ def setup_messages(self): }, ] - def generate( self, target: Type[T], @@ -118,14 +118,18 @@ def generate( Use the document's content as context for generating these instances. Ensure that all generated instances comply with the model's schema and constraints. """ - user_message = instructions if instructions else f"Generate {n} instance(s) of the given model using the document's content as context." + user_message = ( + instructions + if instructions + else f"Generate {n} instance(s) of the given model using the document's content as context." + ) completion_client = Client( api_key=api_key, base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) if process == "batch": @@ -138,7 +142,9 @@ def generate( max_tokens=max_tokens, max_retries=max_retries, temperature=temperature, - mode="markdown_json_mode" if model.startswith(("ollama/", "ollama_chat/")) else mode, + mode="markdown_json_mode" + if model.startswith(("ollama/", "ollama_chat/")) + else mode, response_model=ResponseModel, ) return [response] if n == 1 else response.items @@ -165,11 +171,15 @@ def generate( field_user_message = f"Generate a value for the '{field_name}' field using the document's content as context." if instance: field_user_message += f"\nCurrent partial instance: {instance}" - + if i > 0: - field_user_message += f"\n\nPrevious generations for this field:" - for j, prev_instance in enumerate(results[-min(3, i):], 1): - field_user_message += f"\n{j}. {getattr(prev_instance, field_name)}" + field_user_message += ( + f"\n\nPrevious generations for this field:" + ) + for j, prev_instance in enumerate(results[-min(3, i) :], 1): + field_user_message += ( + f"\n{j}. {getattr(prev_instance, field_name)}" + ) field_user_message += "\n\nPlease generate a different value from these previous ones." field_response = completion_client.completion( @@ -181,15 +191,18 @@ def generate( max_tokens=max_tokens, max_retries=max_retries, temperature=temperature, - mode="markdown_json_mode" if model.startswith(("ollama/", "ollama_chat/")) else mode, - response_model=create_model("FieldResponse", value=(field.annotation, ...)), + mode="markdown_json_mode" + if model.startswith(("ollama/", "ollama_chat/")) + else mode, + response_model=create_model( + "FieldResponse", value=(field.annotation, ...) 
+ ), ) instance[field_name] = field_response.value - + results.append(target(**instance)) - + return results - def completion( self, diff --git a/zyx/lib/utils/logger.py b/zyx/lib/utils/logger.py index afb2144..1f54f1f 100644 --- a/zyx/lib/utils/logger.py +++ b/zyx/lib/utils/logger.py @@ -7,7 +7,6 @@ Author: Hammad Saeed """ - import builtins import logging @@ -19,16 +18,13 @@ builtins.print = rich_print -def get_logger( - module_name: str = __name__, - level: str = "INFO" -) -> logging.Logger: +def get_logger(module_name: str = __name__, level: str = "INFO") -> logging.Logger: logger = logging.getLogger(module_name) - + # Remove any existing handlers to avoid duplicates if logger.hasHandlers(): logger.handlers.clear() - + # Set the logging level logger.setLevel(level) @@ -39,7 +35,7 @@ def get_logger( show_time=True, omit_repeated_times=False, show_level=True, - show_path=False + show_path=False, ) # Set the format for the handler diff --git a/zyx/resources/completions/agents/conversation.py b/zyx/resources/completions/agents/conversation.py index dafa4d1..6d0afec 100644 --- a/zyx/resources/completions/agents/conversation.py +++ b/zyx/resources/completions/agents/conversation.py @@ -5,34 +5,40 @@ from ....lib.utils.logger import get_logger from ..agents.judge import judge, ValidationResult from ...ext.multimodal import OPENAI_TTS_VOICES, OPENAI_TTS_MODELS, audio -from ..base.classify import classify +from ..base.classify import classify logger = get_logger("conversation") + class Character(BaseModel): name: str personality: str knowledge: Optional[str] = None voice: Optional[OPENAI_TTS_VOICES] = None + class Message(BaseModel): role: Literal["user", "assistant"] content: str audio_file: Optional[str] = None + class Conversation(BaseModel): messages: List[Message] audio_file: Optional[str] = None + class EndConversation(BaseModel): should_end: bool explanation: Optional[str] = None confidence: Optional[float] = None + class ConversationEndCheck(BaseModel): should_end: bool explanation: Optional[str] = None + import tempfile import os from pydub import AudioSegment @@ -46,30 +52,36 @@ class ConversationEndCheck(BaseModel): logger = get_logger("conversation") + class Character(BaseModel): name: str personality: str knowledge: Optional[str] = None voice: Optional[OPENAI_TTS_VOICES] = None + class Message(BaseModel): role: Literal["user", "assistant"] content: str audio_file: Optional[str] = None + class Conversation(BaseModel): messages: List[Message] audio_file: Optional[str] = None + class EndConversation(BaseModel): should_end: bool explanation: Optional[str] = None confidence: Optional[float] = None + class ConversationEndCheck(BaseModel): should_end: bool explanation: Optional[str] = None + def conversation( instructions: Union[str, Document], characters: List[Character], @@ -88,7 +100,7 @@ def conversation( verbose: bool = False, generate_audio: bool = False, audio_model: OPENAI_TTS_MODELS = "tts-1", - audio_output_file: Optional[str] = None + audio_output_file: Optional[str] = None, ) -> Conversation: """ Generate a conversation between characters based on given instructions or a Document object, with optional validator. 
@@ -147,7 +159,7 @@ def conversation( base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) conversation = Conversation(messages=[]) @@ -170,7 +182,9 @@ def conversation( for character in characters: if not character.voice: character.voice = available_voices.pop(0) - available_voices.append(character.voice) # Put it back at the end for reuse if needed + available_voices.append( + character.voice + ) # Put it back at the end for reuse if needed system_message = f""" You are simulating a conversation between the following characters: @@ -189,17 +203,17 @@ def conversation( # Create a temporary directory to store audio segments with tempfile.TemporaryDirectory() as temp_dir: logger.info(f"Created temporary directory: {temp_dir}") - + for turn in range(max_turns): current_character = characters[turn % len(characters)] user_message = f"Generate the next message for {current_character.name} in the conversation, focusing on the provided context." - + # Check if we've reached the maximum number of turns if turn == max_turns - 1: # Use the classifier to determine if the conversation should end classifier_result = classify( - inputs=' '.join([msg.content for msg in conversation.messages]), + inputs=" ".join([msg.content for msg in conversation.messages]), labels=["end", "continue"], classification="single", model=model, @@ -209,7 +223,7 @@ def conversation( mode=mode, temperature=temperature, client=client, - verbose=verbose + verbose=verbose, ) if verbose: @@ -221,15 +235,22 @@ def conversation( if classifier_result.label == "continue": # If the classifier says the conversation should not end, add a final summary prompt user_message = f"This is the final turn of the conversation. {current_character.name}, please summarize the key points discussed and provide a concluding statement to end the conversation." - + if end_check_attempts >= max_end_check_attempts: user_message += "\n\n[HIDDEN INSTRUCTION: The conversation should now conclude naturally. 
Provide a final statement or summary.]" response = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} - ] + [{"role": msg.role, "content": f"{characters[i % len(characters)].name}: {msg.content}"} for i, msg in enumerate(conversation.messages)], + {"role": "user", "content": user_message}, + ] + + [ + { + "role": msg.role, + "content": f"{characters[i % len(characters)].name}: {msg.content}", + } + for i, msg in enumerate(conversation.messages) + ], model=model, response_model=Message, mode=mode, @@ -237,17 +258,22 @@ def conversation( temperature=temperature, ) - logger.info(f"Turn {turn + 1}: {current_character.name} - {response.content}") + logger.info( + f"Turn {turn + 1}: {current_character.name} - {response.content}" + ) if generate_audio: - temp_audio_file = os.path.join(temp_dir, f"{current_character.name.lower().replace(' ', '_')}_{turn}.mp3") + temp_audio_file = os.path.join( + temp_dir, + f"{current_character.name.lower().replace(' ', '_')}_{turn}.mp3", + ) logger.info(f"Attempting to generate audio file: {temp_audio_file}") try: # Remove the character's name from the beginning of the content audio_content = response.content if audio_content.startswith(f"{current_character.name}:"): audio_content = audio_content.split(":", 1)[1].strip() - + audio( prompt=audio_content, model=audio_model, @@ -258,13 +284,19 @@ def conversation( ) if os.path.exists(temp_audio_file): response.audio_file = temp_audio_file - logger.info(f"Successfully generated audio file: {temp_audio_file}") + logger.info( + f"Successfully generated audio file: {temp_audio_file}" + ) else: logger.warning(f"Audio file not created: {temp_audio_file}") logger.info(f"Current working directory: {os.getcwd()}") - logger.info(f"Temporary directory contents: {os.listdir(temp_dir)}") + logger.info( + f"Temporary directory contents: {os.listdir(temp_dir)}" + ) except Exception as e: - logger.warning(f"Failed to generate audio for turn {turn}: {str(e)}") + logger.warning( + f"Failed to generate audio for turn {turn}: {str(e)}" + ) logger.exception("Detailed error information:") conversation.messages.append(response) @@ -283,12 +315,17 @@ def conversation( max_retries=max_retries, organization=organization, client=client, - verbose=verbose + verbose=verbose, ) - if isinstance(validation_result, ValidationResult) and not validation_result.is_valid: + if ( + isinstance(validation_result, ValidationResult) + and not validation_result.is_valid + ): if verbose: - logger.warning(f"Message failed validation: {validation_result.explanation}") + logger.warning( + f"Message failed validation: {validation_result.explanation}" + ) continue # Check if we've reached the minimum number of turns @@ -296,8 +333,14 @@ def conversation( # Use the boolean BaseModel for end-of-conversation detection end_check = completion_client.completion( messages=[ - {"role": "system", "content": f"You are evaluating if a conversation should end based on the following criteria: {end_criteria}"}, - {"role": "user", "content": f"Analyze the following conversation and determine if it should end:\n\n{' '.join([msg.content for msg in conversation.messages])}"} + { + "role": "system", + "content": f"You are evaluating if a conversation should end based on the following criteria: {end_criteria}", + }, + { + "role": "user", + "content": f"Analyze the following conversation and determine if it should end:\n\n{' '.join([msg.content for msg in conversation.messages])}", + }, ], model=model, 
response_model=ConversationEndCheck, @@ -317,7 +360,7 @@ def conversation( # Use the classify function to determine if the conversation should end classifier_result = classify( - inputs=' '.join([msg.content for msg in conversation.messages]), + inputs=" ".join([msg.content for msg in conversation.messages]), labels=["end", "continue"], classification="single", model=model, @@ -327,7 +370,7 @@ def conversation( mode=mode, temperature=temperature, client=client, - verbose=verbose + verbose=verbose, ) if verbose: @@ -353,7 +396,9 @@ def conversation( else: logger.warning(f"Audio file not found: {msg.audio_file}") except Exception as e: - logger.warning(f"Error processing audio file {msg.audio_file}: {str(e)}") + logger.warning( + f"Error processing audio file {msg.audio_file}: {str(e)}" + ) if not combined.empty(): combined.export(audio_output_file, format="mp3") @@ -364,33 +409,46 @@ def conversation( return conversation + if __name__ == "__main__": # Example usage with a Document object from ....lib.types.document import Document doc = Document( content="The impact of artificial intelligence on job markets has been a topic of intense debate. While AI has the potential to automate many tasks and potentially displace some jobs, it also has the capacity to create new job opportunities and enhance productivity in various sectors. The key challenge lies in managing this transition and ensuring that the workforce is adequately prepared for the changes ahead.", - metadata={"type": "research_summary", "topic": "AI and Employment"} + metadata={"type": "research_summary", "topic": "AI and Employment"}, ) result = conversation( instructions=doc, characters=[ - Character(name="AI Researcher", personality="Optimistic about AI's potential to create new job opportunities", voice="nova"), - Character(name="Labor Economist", personality="Concerned about potential job displacement due to AI", voice="onyx"), - Character(name="Podcast Host", personality="Neutral moderator, asks probing questions to both guests", voice="echo") + Character( + name="AI Researcher", + personality="Optimistic about AI's potential to create new job opportunities", + voice="nova", + ), + Character( + name="Labor Economist", + personality="Concerned about potential job displacement due to AI", + voice="onyx", + ), + Character( + name="Podcast Host", + personality="Neutral moderator, asks probing questions to both guests", + voice="echo", + ), ], min_turns=12, max_turns=20, end_criteria="The podcast should conclude when both guests have shared their final thoughts and the host has summarized the key points of the discussion", verbose=True, generate_audio=True, - audio_output_file="ai_job_market_podcast.mp3" + audio_output_file="ai_job_market_podcast.mp3", ) print("\nGenerated Podcast Conversation:") for msg in result.messages: print(f"{msg.role.capitalize()}: {msg.content}") - + if result.audio_file: - print(f"\nFull conversation audio saved to: {result.audio_file}") \ No newline at end of file + print(f"\nFull conversation audio saved to: {result.audio_file}") diff --git a/zyx/resources/completions/agents/judge.py b/zyx/resources/completions/agents/judge.py index f554a7c..006e1f1 100644 --- a/zyx/resources/completions/agents/judge.py +++ b/zyx/resources/completions/agents/judge.py @@ -2,10 +2,7 @@ from pydantic import BaseModel, Field from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode logger = get_logger("judge") @@ -15,18 +12,22 @@ class 
JudgmentResult(BaseModel): explanation: str verdict: str + class ValidationResult(BaseModel): is_valid: bool explanation: str + class RegeneratedResponse(BaseModel): response: str + class FactCheckResult(BaseModel): is_accurate: bool explanation: str confidence: float = Field(..., ge=0.0, le=1.0) + def judge( prompt: str, responses: Optional[Union[List[str], str]] = None, @@ -42,7 +43,7 @@ def judge( organization: Optional[str] = None, client: Optional[Literal["openai", "litellm"]] = None, verbose: bool = False, - guardrails: Optional[Union[str, List[str]]] = None + guardrails: Optional[Union[str, List[str]]] = None, ) -> Union[JudgmentResult, ValidationResult, RegeneratedResponse, FactCheckResult]: """ Judge responses based on accuracy, validate against a schema, or fact-check a single response, @@ -110,7 +111,7 @@ def judge( base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) if process == "accuracy": @@ -122,11 +123,11 @@ def judge( user_message = f"Prompt: {prompt}\n\nResponses:\n" for idx, response in enumerate(responses, 1): user_message += f"{idx}. {response}\n\n" - + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=JudgmentResult, @@ -134,9 +135,8 @@ def judge( max_retries=max_retries, temperature=temperature, ) - - if regenerate: + if regenerate: if verbose: logger.warning(f"Response is not accurate. Regenerating response.") @@ -145,11 +145,11 @@ def judge( "that addresses the prompt more effectively than the original responses." ) user_message = f"Original prompt: {prompt}\n\nJudgment: {result.explanation}\n\nGenerate an optimized response:" - + regenerated = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=RegeneratedResponse, @@ -158,22 +158,22 @@ def judge( temperature=temperature, ) result = regenerated - + elif process == "validate": if not schema: raise ValueError("Schema is required for validation.") - + system_message = ( "You are a validation expert. Your task is to determine if the given response " "matches the provided schema or instructions. Provide a detailed explanation " "of your validation process and state whether the response is valid or not." ) user_message = f"Prompt: {prompt}\n\nResponse: {responses[0]}\n\nSchema/Instructions: {schema}" - + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=ValidationResult, @@ -181,9 +181,8 @@ def judge( max_retries=max_retries, temperature=temperature, ) - - if regenerate and not result.is_valid: + if regenerate and not result.is_valid: if verbose: logger.warning(f"Response is not valid. Regenerating response.") @@ -192,11 +191,11 @@ def judge( "correctly adheres to the given schema or instructions." 
) user_message = f"Original prompt: {prompt}\n\nSchema/Instructions: {schema}\n\nGenerate a valid response:" - + regenerated = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=RegeneratedResponse, @@ -205,13 +204,13 @@ def judge( temperature=temperature, ) result = regenerated - + elif process == "fact_check": if responses is None: responses = [prompt] # Use the prompt as the response for fact-checking elif len(responses) != 1: raise ValueError("Fact-check requires exactly one response.") - + system_message = ( "You are a fact-checking expert. Your task is to determine if the given response " "is accurate based on the prompt and your knowledge. Provide a detailed explanation " @@ -221,11 +220,11 @@ def judge( user_message = f"Prompt: {prompt}\n\nResponse to fact-check: {responses[0]}" if schema: user_message += f"\n\nAdditional fact-checking guidelines: {schema}" - + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=FactCheckResult, @@ -233,9 +232,8 @@ def judge( max_retries=max_retries, temperature=temperature, ) - - if regenerate and not result.is_accurate: + if regenerate and not result.is_accurate: if verbose: logger.warning(f"Response is not accurate. Regenerating response.") @@ -244,11 +242,11 @@ def judge( "is accurate and addresses the original prompt correctly." ) user_message = f"Original prompt: {prompt}\n\nFact-check result: {result.explanation}\n\nGenerate an accurate response:" - + regenerated = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=RegeneratedResponse, @@ -257,37 +255,70 @@ def judge( temperature=temperature, ) result = regenerated - + else: - raise ValueError("Invalid process. Choose 'accuracy', 'validate', or 'fact_check'.") + raise ValueError( + "Invalid process. Choose 'accuracy', 'validate', or 'fact_check'." + ) # Add guardrails check after the main process if guardrails: - guardrails_result = check_guardrails(prompt, result, guardrails, completion_client, model, mode, max_retries, temperature, verbose) + guardrails_result = check_guardrails( + prompt, + result, + guardrails, + completion_client, + model, + mode, + max_retries, + temperature, + verbose, + ) if not guardrails_result.passed: if verbose: logger.warning(f"Response violates guardrails. Regenerating response.") - result = regenerate_response(prompt, guardrails_result.explanation, completion_client, model, mode, max_retries, temperature) + result = regenerate_response( + prompt, + guardrails_result.explanation, + completion_client, + model, + mode, + max_retries, + temperature, + ) return result -def check_guardrails(prompt, result, guardrails, completion_client, model, mode, max_retries, temperature, verbose): + +def check_guardrails( + prompt, + result, + guardrails, + completion_client, + model, + mode, + max_retries, + temperature, + verbose, +): if isinstance(guardrails, str): guardrails = [guardrails] - + guardrails_prompt = "\n".join(guardrails) - + system_message = ( "You are a content moderator. Your task is to determine if the given response " "violates any of the specified guardrails. 
Provide a detailed explanation " "of your evaluation process and state whether the response passes all guardrails or not." ) - user_message = f"Prompt: {prompt}\n\nResponse: {result}\n\nGuardrails:\n{guardrails_prompt}" - + user_message = ( + f"Prompt: {prompt}\n\nResponse: {result}\n\nGuardrails:\n{guardrails_prompt}" + ) + guardrails_result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=GuardrailsResult, @@ -295,20 +326,23 @@ def check_guardrails(prompt, result, guardrails, completion_client, model, mode, max_retries=max_retries, temperature=temperature, ) - + return guardrails_result -def regenerate_response(prompt, explanation, completion_client, model, mode, max_retries, temperature): + +def regenerate_response( + prompt, explanation, completion_client, model, mode, max_retries, temperature +): system_message = ( "Based on the guardrails violation, generate a new response that " "addresses the original prompt while adhering to all specified guardrails." ) user_message = f"Original prompt: {prompt}\n\nGuardrails violation: {explanation}\n\nGenerate a compliant response:" - + regenerated = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=RegeneratedResponse, @@ -318,32 +352,42 @@ def regenerate_response(prompt, explanation, completion_client, model, mode, max ) return regenerated + class GuardrailsResult(BaseModel): passed: bool explanation: str + if __name__ == "__main__": # Example usage prompt = "Explain the concept of quantum entanglement." responses = [ "Quantum entanglement is a phenomenon where two particles become interconnected and their quantum states cannot be described independently.", - "Quantum entanglement is when particles are really close to each other and move in the same way." + "Quantum entanglement is when particles are really close to each other and move in the same way.", ] - + # Accuracy judgment result = judge(prompt, responses, process="accuracy", verbose=True) - print(f"Accuracy Judgment:\nExplanation: {result.explanation}\nVerdict: {result.verdict}") - + print( + f"Accuracy Judgment:\nExplanation: {result.explanation}\nVerdict: {result.verdict}" + ) + # Validation schema = "The response should include: 1) Definition of quantum entanglement, 2) Its importance in quantum mechanics, 3) An example or application." - result = judge(prompt, [responses[0]], process="validate", schema=schema, verbose=True) - print(f"\nValidation Result:\nIs Valid: {result.is_valid}\nExplanation: {result.explanation}") - + result = judge( + prompt, [responses[0]], process="validate", schema=schema, verbose=True + ) + print( + f"\nValidation Result:\nIs Valid: {result.is_valid}\nExplanation: {result.explanation}" + ) + # Fact-check fact_check_response = "Quantum entanglement occurs when two particles are separated by a large distance but still instantaneously affect each other's quantum states." 
result = judge(prompt, [fact_check_response], process="fact_check", verbose=True) - print(f"\nFact-Check Result:\nIs Accurate: {result.is_accurate}\nExplanation: {result.explanation}\nConfidence: {result.confidence}") - + print( + f"\nFact-Check Result:\nIs Accurate: {result.is_accurate}\nExplanation: {result.explanation}\nConfidence: {result.confidence}" + ) + # Regeneration result = judge(prompt, responses, process="accuracy", regenerate=True, verbose=True) - print(f"\nRegenerated Response:\n{result.response}") \ No newline at end of file + print(f"\nRegenerated Response:\n{result.response}") diff --git a/zyx/resources/completions/agents/plan.py b/zyx/resources/completions/agents/plan.py index 133658f..61aea70 100644 --- a/zyx/resources/completions/agents/plan.py +++ b/zyx/resources/completions/agents/plan.py @@ -1,20 +1,20 @@ from pydantic import BaseModel, create_model, Field from typing import Optional, List, Union, Literal, Type, Any from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode logger = get_logger("plan") + class Task(BaseModel): description: str details: Optional[str] = None + class Plan(BaseModel): tasks: List[Task] + def plan( input: Union[str, Type[BaseModel]], instructions: Optional[str] = None, @@ -77,7 +77,7 @@ def plan( base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) if isinstance(input, str): @@ -89,13 +89,15 @@ def plan( else: raise ValueError("Input must be either a string or a Pydantic model class.") - user_message = instructions if instructions else f"Generate a plan with {steps} steps." + user_message = ( + instructions if instructions else f"Generate a plan with {steps} steps." + ) if process == "single" or n == 1: result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=response_model, @@ -105,18 +107,20 @@ def plan( ) return result else: # batch process - batch_response_model = create_model("ResponseModel", items=(List[response_model], ...)) + batch_response_model = create_model( + "ResponseModel", items=(List[response_model], ...) + ) results = [] for i in range(0, n, batch_size): batch_n = min(batch_size, n - i) batch_message = f"Generate {batch_n} plans, each with {steps} steps." if results: batch_message += f"\nPreviously generated plans: {results[-3:]}\nEnsure these new plans are different." - + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": batch_message} + {"role": "user", "content": batch_message}, ], model=model, response_model=batch_response_model, @@ -124,11 +128,12 @@ def plan( max_retries=max_retries, temperature=temperature, ) - + results.extend(result.items) - + return results + def _get_string_system_message(input: str, steps: int) -> str: return f""" You are a planning assistant using the Tree of Thoughts method. Your task is to generate a detailed plan based on the given input. @@ -140,6 +145,7 @@ def _get_string_system_message(input: str, steps: int) -> str: Return the tasks as a Plan object with a list of Task objects. """ + def _get_model_system_message(input_model: Type[BaseModel], steps: int) -> str: return f""" You are a planning assistant using the Tree of Thoughts method. Your task is to generate a detailed plan based on the given Pydantic model. 
@@ -156,6 +162,7 @@ def _get_model_system_message(input_model: Type[BaseModel], steps: int) -> str: Return the tasks as a Plan object with a list of Task objects. """ + if __name__ == "__main__": # Example usage with string input goal = "Create a marketing strategy for a new smartphone" @@ -184,7 +191,9 @@ class ResearchTask(BaseModel): print() # Batch processing example - batch_results = plan(ResearchTask, n=2, process="batch", batch_size=2, steps=3, verbose=True) + batch_results = plan( + ResearchTask, n=2, process="batch", batch_size=2, steps=3, verbose=True + ) print("Batch plans for Pydantic model input:") for i, plan in enumerate(batch_results, 1): print(f"Plan {i}:") @@ -192,4 +201,4 @@ class ResearchTask(BaseModel): print(f"- Topic: {task.topic}") print(f" Resources: {', '.join(task.resources)}") print(f" Estimated Time: {task.estimated_time}") - print() \ No newline at end of file + print() diff --git a/zyx/resources/completions/agents/query.py b/zyx/resources/completions/agents/query.py index 9ca4522..5add513 100644 --- a/zyx/resources/completions/agents/query.py +++ b/zyx/resources/completions/agents/query.py @@ -4,13 +4,11 @@ from typing import List, Optional, Union, Literal, Callable from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode logger = get_logger("workflow") + class EnumAgentRoles(Enum): SUPERVISOR = "supervisor" PLANNER = "planner" @@ -21,6 +19,7 @@ class EnumAgentRoles(Enum): TOOL = "tool" RETRIEVER = "retriever" + class EnumWorkflowState(Enum): IDLE = "idle" CHAT = "chat" @@ -33,15 +32,18 @@ class EnumWorkflowState(Enum): USING_TOOL = "using_tool" RETRIEVING = "retrieving" + class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) description: str + class Plan(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) goal: str tasks: List[Task] = Field(default_factory=list) + class Workflow(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_goal: Optional[str] = None @@ -52,10 +54,12 @@ class Workflow(BaseModel): completed_tasks: List[Task] = Field(default_factory=list) task_queue: List[Task] = Field(default_factory=list) + class UserIntent(BaseModel): intent: str confidence: float + class QueryParams(BaseModel): model: str = "gpt-4o-mini" api_key: Optional[str] = None @@ -76,9 +80,11 @@ class QueryParams(BaseModel): stream: Optional[bool] = False verbose: Optional[bool] = False + class TaskCompletionCheck(BaseModel): is_complete: bool + def classify_intent(user_input: str, params: QueryParams) -> UserIntent: intent_labels = [ "chat", @@ -98,7 +104,7 @@ def classify_intent(user_input: str, params: QueryParams) -> UserIntent: base_url=params.base_url, organization=params.organization, provider="openai", - verbose=params.verbose + verbose=params.verbose, ) classification = completion_client.completion( @@ -114,6 +120,7 @@ def classify_intent(user_input: str, params: QueryParams) -> UserIntent: ) return classification + def generate_plan(goal: str, params: QueryParams) -> Plan: system_message = f"Generate a plan for the following goal: {goal}" user_message = goal @@ -122,7 +129,7 @@ def generate_plan(goal: str, params: QueryParams) -> Plan: api_key=params.api_key, base_url=params.base_url, organization=params.organization, - verbose=params.verbose + verbose=params.verbose, ) plan_response = completion_client.completion( @@ -138,6 +145,7 @@ def generate_plan(goal: str, params: QueryParams) -> Plan: ) return 
plan_response + def execute_task(task: Task, params: QueryParams) -> str: system_message = f"Execute the following task: {task.description}" user_message = task.description @@ -147,7 +155,7 @@ def execute_task(task: Task, params: QueryParams) -> str: base_url=params.base_url, organization=params.organization, provider="openai", - verbose=params.verbose + verbose=params.verbose, ) execute_response = completion_client.completion( @@ -162,6 +170,7 @@ def execute_task(task: Task, params: QueryParams) -> str: ) return execute_response.choices[0].message.content + def check_task_completion(task: Task, result: str, params: QueryParams) -> bool: system_message = f"Check if the following task is complete: {task.description}" user_message = result @@ -171,7 +180,7 @@ def check_task_completion(task: Task, result: str, params: QueryParams) -> bool: base_url=params.base_url, organization=params.organization, provider="openai", - verbose=params.verbose + verbose=params.verbose, ) check_response = completion_client.completion( @@ -187,6 +196,7 @@ def check_task_completion(task: Task, result: str, params: QueryParams) -> bool: ) return check_response.is_complete + def query( prompt: str, model: str = "gpt-4o-mini", @@ -233,7 +243,7 @@ def query( api_key=params.api_key, base_url=params.base_url, organization=params.organization, - verbose=params.verbose + verbose=params.verbose, ) workflow = Workflow() @@ -280,7 +290,7 @@ def query( temperature=params.temperature, ) return reflection_response.choices[0].message.content - + workflow.state = EnumWorkflowState.COMPLETING final_summary = completion_client.completion( messages=[ @@ -317,5 +327,6 @@ def query( ) return response.choices[0].message.content + if __name__ == "__main__": - print(query("I want to learn how to code", verbose=True)) \ No newline at end of file + print(query("I want to learn how to code", verbose=True)) diff --git a/zyx/resources/completions/agents/scrape.py b/zyx/resources/completions/agents/scrape.py index c4028c9..aa42173 100644 --- a/zyx/resources/completions/agents/scrape.py +++ b/zyx/resources/completions/agents/scrape.py @@ -4,11 +4,7 @@ from pydantic import BaseModel, Field from enum import Enum -from ....client import ( - Client, - InstructorMode, - ToolType -) +from ....client import Client, InstructorMode, ToolType from ....lib.types.document import Document @@ -42,11 +38,13 @@ class ScrapingStep(Enum): EVALUATE = "evaluate" REFINE = "refine" + class StepResult(BaseModel): is_successful: bool explanation: str content: Optional[str] = None + class ScrapeWorkflow(BaseModel): query: str current_step: ScrapingStep = ScrapingStep.SEARCH @@ -55,6 +53,7 @@ class ScrapeWorkflow(BaseModel): summary: Optional[str] = None evaluation: Optional[StepResult] = None + def scrape( query: str, max_results: Optional[int] = 5, @@ -89,7 +88,7 @@ def scrape( temperature: The temperature to use for completion. run_tools: Whether to run tools for completion. tools: The tools to use for completion. - + Returns: A Document object containing the summary and metadata. 
@@ -98,10 +97,7 @@ def scrape( from bs4 import BeautifulSoup client = Client( - api_key = api_key, - base_url = base_url, - provider = client, - verbose = verbose + api_key=api_key, base_url=base_url, provider=client, verbose=verbose ) workflow = ScrapeWorkflow(query=query) @@ -162,7 +158,7 @@ def tag_visible(element): content = future.result() if content: contents.append(content) - + workflow.fetched_contents = contents if verbose: @@ -214,7 +210,7 @@ def tag_visible(element): f"Summary:\n{summary}\n\n" "Provide an explanation of your evaluation and determine if the summary is successful or needs refinement." ) - + evaluation_response = client.completion( messages=[ {"role": "system", "content": "You are an expert evaluator of summaries."}, @@ -226,7 +222,7 @@ def tag_visible(element): max_retries=max_retries, temperature=temperature, ) - + workflow.evaluation = evaluation_response # Step 7: Refine if necessary @@ -238,10 +234,13 @@ def tag_visible(element): f"Evaluation feedback:\n{evaluation_response.explanation}\n\n" "Please provide an improved and refined summary addressing the feedback." ) - + refined_response = client.completion( messages=[ - {"role": "system", "content": "You are an expert at refining and improving summaries."}, + { + "role": "system", + "content": "You are an expert at refining and improving summaries.", + }, {"role": "user", "content": refine_prompt}, ], model=model, @@ -249,7 +248,7 @@ def tag_visible(element): max_retries=max_retries, temperature=temperature, ) - + summary = refined_response.choices[0].message.content if verbose: @@ -280,4 +279,4 @@ def tag_visible(element): verbose=True, ) print("Final Document:") - print(result_document.content) \ No newline at end of file + print(result_document.content) diff --git a/zyx/resources/completions/agents/solve.py b/zyx/resources/completions/agents/solve.py index 44c8b3c..1cea7f3 100644 --- a/zyx/resources/completions/agents/solve.py +++ b/zyx/resources/completions/agents/solve.py @@ -1,39 +1,38 @@ from pydantic import BaseModel -from typing import ( - List, - Literal, - Optional, - Union -) +from typing import List, Literal, Optional, Union from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode logger = get_logger("solve") + class Thought(BaseModel): content: str score: float + class Thoughts(BaseModel): thoughts: List[Thought] + class HighLevelConcept(BaseModel): concept: str + class FinalAnswer(BaseModel): answer: str + class TreeNode(BaseModel): thought: Thought children: List["TreeNode"] = [] + class TreeOfThoughtResult(BaseModel): final_answer: str reasoning_tree: TreeNode + def solve( problem: str, use_high_level_concept: bool = False, @@ -91,7 +90,7 @@ def solve( base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) if use_high_level_concept: @@ -108,12 +107,14 @@ def solve( - Do not hallucinate or make up a concept. 
""" - user_message = f"Problem: {problem}\nProvide a high-level concept related to this problem:" + user_message = ( + f"Problem: {problem}\nProvide a high-level concept related to this problem:" + ) high_level_concept_response = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=HighLevelConcept, @@ -142,7 +143,7 @@ def solve( final_answer_response = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=FinalAnswer, @@ -154,6 +155,7 @@ def solve( return final_answer_response if use_tree_of_thought: + def generate_thoughts(current_problem: str, depth: int) -> TreeNode: if depth >= max_depth: return TreeNode(thought=Thought(content="Reached max depth", score=0)) @@ -172,7 +174,7 @@ def generate_thoughts(current_problem: str, depth: int) -> TreeNode: response = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=Thoughts, @@ -210,7 +212,7 @@ def generate_thoughts(current_problem: str, depth: int) -> TreeNode: final_answer_response = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=FinalAnswer, @@ -219,7 +221,9 @@ def generate_thoughts(current_problem: str, depth: int) -> TreeNode: temperature=temperature, ) - return TreeOfThoughtResult(final_answer=final_answer_response.answer, reasoning_tree=root) + return TreeOfThoughtResult( + final_answer=final_answer_response.answer, reasoning_tree=root + ) # Default to chain-of-thought approach system_message = f""" @@ -232,12 +236,14 @@ def generate_thoughts(current_problem: str, depth: int) -> TreeNode: - Do not hallucinate or make up information. 
""" - user_message = f"Problem: {problem}\nGenerate a sequence of thoughts to solve the problem:" + user_message = ( + f"Problem: {problem}\nGenerate a sequence of thoughts to solve the problem:" + ) response = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=Thoughts, @@ -246,17 +252,20 @@ def generate_thoughts(current_problem: str, depth: int) -> TreeNode: temperature=temperature, ) - final_answer = response.thoughts[-1].content # Assuming the last thought is the final answer + final_answer = response.thoughts[ + -1 + ].content # Assuming the last thought is the final answer return FinalAnswer(answer=final_answer) + if __name__ == "__main__": result = solve( "What is the significance of the Pythagorean theorem in mathematics?", verbose=True, use_high_level_concept=True, use_tree_of_thought=True, - batch_size = 1, + batch_size=1, ) print(f"Final answer: {result.answer}") if isinstance(result, TreeOfThoughtResult): diff --git a/zyx/resources/completions/base/classify.py b/zyx/resources/completions/base/classify.py index fd1573e..4f8007e 100644 --- a/zyx/resources/completions/base/classify.py +++ b/zyx/resources/completions/base/classify.py @@ -1,36 +1,27 @@ from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode from pydantic import BaseModel, create_model -from typing import ( - List, - Literal, - Optional, - Union -) +from typing import List, Literal, Optional, Union logger = get_logger("classify") def classify( - inputs : Union[str, List[str]], - labels : List[str], - classification : Literal["single", "multi"] = "single", - n : int = 1, - batch_size : int = 3, - model : str = "gpt-4o-mini", - api_key : Optional[str] = None, - base_url : Optional[str] = None, - organization : Optional[str] = None, - mode : InstructorMode = "tool_call", - temperature : Optional[float] = None, - client : Optional[Literal["openai", "litellm"]] = None, - verbose : bool = False + inputs: Union[str, List[str]], + labels: List[str], + classification: Literal["single", "multi"] = "single", + n: int = 1, + batch_size: int = 3, + model: str = "gpt-4o-mini", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + organization: Optional[str] = None, + mode: InstructorMode = "tool_call", + temperature: Optional[float] = None, + client: Optional[Literal["openai", "litellm"]] = None, + verbose: bool = False, ) -> List: - """ Classifies given input(s) into one or more of the provided labels. @@ -67,12 +58,12 @@ def classify( logger.info(f"Classification Mode: {classification}") class ClassificationResult(BaseModel): - text : str - label : str + text: str + label: str class MultiClassificationResult(BaseModel): - text : str - labels : List[str] + text: str + labels: List[str] if classification == "single": system_message = f""" @@ -86,8 +77,7 @@ class MultiClassificationResult(BaseModel): response_model = ClassificationResult elif batch_size > 1: response_model = create_model( - "ClassificationResult", - items = (List[ClassificationResult], ...) + "ClassificationResult", items=(List[ClassificationResult], ...) 
) else: raise ValueError("Batch size must be a positive integer.") @@ -103,28 +93,27 @@ class MultiClassificationResult(BaseModel): response_model = MultiClassificationResult elif batch_size > 1: response_model = create_model( - "ClassificationResult", - items = (List[MultiClassificationResult], ...) + "ClassificationResult", items=(List[MultiClassificationResult], ...) ) else: raise ValueError("Batch size must be a positive integer.") - + if isinstance(inputs, str): inputs = [inputs] results = [] completion_client = Client( - api_key = api_key, - base_url = base_url, - organization = organization, - provider = client, - verbose = verbose + api_key=api_key, + base_url=base_url, + organization=organization, + provider=client, + verbose=verbose, ) for i in range(0, len(inputs), batch_size): - batch = inputs[i:i+batch_size] - + batch = inputs[i : i + batch_size] + user_message = "Classify the following text(s):\n\n" for idx, text in enumerate(batch, 1): user_message += f"{idx}. {text}\n\n" @@ -132,7 +121,7 @@ class MultiClassificationResult(BaseModel): result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=response_model, @@ -147,19 +136,14 @@ class MultiClassificationResult(BaseModel): return results if len(results) > 1 else results[0] - -if __name__ == "__main__": +if __name__ == "__main__": items = [ "I love programming in Python", "I like french fries", - "I love programming in Julia" + "I love programming in Julia", ] labels = ["code", "food"] - print(classify(items, labels, classification = "single", batch_size = 2, verbose = True)) - - - - + print(classify(items, labels, classification="single", batch_size=2, verbose=True)) diff --git a/zyx/resources/completions/base/code.py b/zyx/resources/completions/base/code.py index f5fd2c5..cafe6d8 100644 --- a/zyx/resources/completions/base/code.py +++ b/zyx/resources/completions/base/code.py @@ -1,14 +1,7 @@ from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode from pydantic import BaseModel, Field -from typing import ( - Any, - Literal, - Optional -) +from typing import Any, Literal, Optional import traceback import tempfile import sys @@ -28,9 +21,8 @@ def code( temperature: Optional[float] = None, client: Optional[Literal["openai", "litellm"]] = None, verbose: bool = False, - **kwargs + **kwargs, ) -> Any: - """ Generates, executes and returns results of python code. 
""" @@ -58,20 +50,20 @@ class CodeGenerationModel(BaseModel): base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) try: response = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=CodeGenerationModel, mode=mode, temperature=temperature, - **kwargs + **kwargs, ) if verbose: @@ -97,10 +89,10 @@ class CodeGenerationModel(BaseModel): exec(response.code, {}, local_namespace) # Return the result object - if 'result' not in local_namespace: + if "result" not in local_namespace: raise ValueError("No result object found in the generated code.") - return local_namespace['result'] - + return local_namespace["result"] + finally: # Clean up: remove the temporary file os.unlink(temp_file_path) @@ -111,9 +103,13 @@ class CodeGenerationModel(BaseModel): print(f"Traceback: {traceback.format_exc()}") raise + if __name__ == "__main__": # Generate a logger object - generated_logger = code("create a logger named 'my_logger' that logs to console with INFO level", verbose=True) + generated_logger = code( + "create a logger named 'my_logger' that logs to console with INFO level", + verbose=True, + ) # Use the generated logger - generated_logger.info("This is a test log message") \ No newline at end of file + generated_logger.info("This is a test log message") diff --git a/zyx/resources/completions/base/extract.py b/zyx/resources/completions/base/extract.py index 16ea594..bdaaf85 100644 --- a/zyx/resources/completions/base/extract.py +++ b/zyx/resources/completions/base/extract.py @@ -1,16 +1,7 @@ from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode from pydantic import BaseModel, create_model -from typing import ( - List, - Literal, - Optional, - Type, - Union -) +from typing import List, Literal, Optional, Type, Union logger = get_logger("extract") @@ -41,7 +32,7 @@ def extract( class User(BaseModel): name: str age: int - + zyx.extract(User, "John is 20 years old") ``` @@ -69,7 +60,9 @@ class User(BaseModel): text = [text] if verbose: - logger.info(f"Extracting information from {len(text)} text(s) into {target.__name__} model.") + logger.info( + f"Extracting information from {len(text)} text(s) into {target.__name__} model." + ) logger.info(f"Using model: {model}") logger.info(f"Batch size: {batch_size}") logger.info(f"Process: {process}") @@ -91,24 +84,24 @@ class User(BaseModel): base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) results = [] if process == "single": response_model = target - + for i in range(0, len(text), batch_size): - batch = text[i:i+batch_size] + batch = text[i : i + batch_size] user_message = "Extract information from the following text(s) and fit it into the given model:\n\n" for idx, t in enumerate(batch, 1): user_message += f"{idx}. 
{t}\n\n" - + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=response_model, @@ -116,23 +109,23 @@ class User(BaseModel): max_retries=max_retries, temperature=temperature, ) - + results.append(result) - + return results if len(results) > 1 else results[0] else: # batch process for i in range(0, len(text), batch_size): - batch = text[i:i+batch_size] + batch = text[i : i + batch_size] batch_message = "Extract information from the following texts and fit it into the given model:\n\n" for idx, t in enumerate(batch, 1): batch_message += f"{idx}. {t}\n\n" - + response_model = create_model("ResponseModel", items=(List[target], ...)) - + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": batch_message} + {"role": "user", "content": batch_message}, ], model=model, response_model=response_model, @@ -140,22 +133,20 @@ class User(BaseModel): max_retries=max_retries, temperature=temperature, ) - + results.extend(result.items) - + return results + if __name__ == "__main__": + class User(BaseModel): name: str age: int - text = [ - "John is 20 years old", - "Alice is 30 years old", - "Bob is 25 years old" - ] + text = ["John is 20 years old", "Alice is 30 years old", "Bob is 25 years old"] results = extract(User, text, process="batch", batch_size=2, verbose=True) for result in results: - print(result) \ No newline at end of file + print(result) diff --git a/zyx/resources/completions/base/function.py b/zyx/resources/completions/base/function.py index fa7a38e..f36d632 100644 --- a/zyx/resources/completions/base/function.py +++ b/zyx/resources/completions/base/function.py @@ -1,15 +1,6 @@ -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode from pydantic import BaseModel, Field -from typing import ( - Callable, - Optional, - Literal, - get_type_hints, - Any -) +from typing import Callable, Optional, Literal, get_type_hints, Any import traceback import logging import importlib @@ -18,12 +9,13 @@ import os import tempfile + class FunctionResponse(BaseModel): code: str output: Any -def prompt_user_library_install(libs : str) -> None: +def prompt_user_library_install(libs: str) -> None: """Prompts user to install the required libraries for the function to run, installs if user enters y""" import subprocess @@ -32,8 +24,10 @@ def prompt_user_library_install(libs : str) -> None: print(f"The function requires the following libraries to run: {libs}") install_prompt = input("Do you want to install these libraries? (y/n): ") - if install_prompt.lower() == 'y': - subprocess.check_call([sys.executable, "-m", "pip", "install", *libs.split(',')]) + if install_prompt.lower() == "y": + subprocess.check_call( + [sys.executable, "-m", "pip", "install", *libs.split(",")] + ) print("Libraries installed successfully.") @@ -91,7 +85,8 @@ class CodeGenerationModel(BaseModel): ..., description="Complete Python code as a single string" ) explanation: Optional[str] = Field( - None, description="An optional explanation for the code. Not required, but any comments should go here." + None, + description="An optional explanation for the code. 
Not required, but any comments should go here.", ) error_context = "" @@ -135,7 +130,7 @@ class CodeGenerationModel(BaseModel): api_key=api_key, base_url=base_url, provider=client, - verbose=verbose + verbose=verbose, ) response = completion_client.completion( @@ -169,26 +164,38 @@ class CodeGenerationModel(BaseModel): # Execute the generated code in a local namespace local_namespace = {} exec_globals = globals().copy() - + # Dynamically import required modules - for line in full_code.split('\n'): - if line.startswith('import ') or line.startswith('from '): + for line in full_code.split("\n"): + if line.startswith("import ") or line.startswith("from "): try: exec(line, exec_globals) except ImportError as e: print(f"Failed to import: {line}. Error: {str(e)}") - print("Attempting to install the required package...") - package = line.split()[1].split('.')[0] - subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + print( + "Attempting to install the required package..." + ) + package = line.split()[1].split(".")[0] + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + package, + ] + ) exec(line, exec_globals) - + exec(full_code, exec_globals, local_namespace) - if 'result' not in local_namespace: - raise ValueError("No result object found in the generated code.") - - result = local_namespace['result'] - + if "result" not in local_namespace: + raise ValueError( + "No result object found in the generated code." + ) + + result = local_namespace["result"] + if verbose: print(f"Result type: {type(result)}") if isinstance(result, logging.Logger): @@ -199,7 +206,7 @@ class CodeGenerationModel(BaseModel): if return_code: return FunctionResponse(code=full_code, output=result) return result - + finally: # Clean up: remove the temporary file os.unlink(temp_file_path) @@ -209,9 +216,7 @@ class CodeGenerationModel(BaseModel): print(f"Import error: {str(e)}") print(f"Traceback: {traceback.format_exc()}") prompt_user_library_install(e.name) - raise RuntimeError( - f"Import error: {str(e)}" - ) + raise RuntimeError(f"Import error: {str(e)}") except Exception as e: print(f"Error in code generation or execution: {str(e)}") @@ -241,10 +246,7 @@ class CodeGenerationModel(BaseModel): ] completion_client = Client( - api_key=api_key, - base_url=base_url, - provider=client, - verbose=verbose + api_key=api_key, base_url=base_url, provider=client, verbose=verbose ) response = completion_client.completion( @@ -282,4 +284,4 @@ def get_logger(name: str): """ logger = get_logger("my_logger") - logger.info("Hello, world!") \ No newline at end of file + logger.info("Hello, world!") diff --git a/zyx/resources/completions/base/generate.py b/zyx/resources/completions/base/generate.py index a454cdc..ddfdba5 100644 --- a/zyx/resources/completions/base/generate.py +++ b/zyx/resources/completions/base/generate.py @@ -1,38 +1,28 @@ from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode from pydantic import BaseModel, create_model -from typing import ( - List, - Literal, - Optional, - Type, - Union -) +from typing import List, Literal, Optional, Type, Union logger = get_logger("generate") def generate( - target: Type[BaseModel], - instructions: Optional[str] = None, - process: Literal["single", "batch"] = "single", - n: int = 1, - batch_size: int = 3, - model: str = "gpt-4o-mini", - api_key: Optional[str] = None, - base_url: Optional[str] = None, - organization: Optional[str] = None, - temperature: 
Optional[float] = None, - max_retries: int = 3, - mode: InstructorMode = "tool_call", - client: Optional[Literal["openai", "litellm"]] = None, - verbose: bool = False + target: Type[BaseModel], + instructions: Optional[str] = None, + process: Literal["single", "batch"] = "single", + n: int = 1, + batch_size: int = 3, + model: str = "gpt-4o-mini", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + organization: Optional[str] = None, + temperature: Optional[float] = None, + max_retries: int = 3, + mode: InstructorMode = "tool_call", + client: Optional[Literal["openai", "litellm"]] = None, + verbose: bool = False, ) -> Union[BaseModel, List[BaseModel]]: - """ Generates a single or batch of pydantic models based on the provided target schema. """ @@ -51,23 +41,31 @@ def generate( Ensure that all generated instances comply with the model's schema and constraints. """ - user_message = instructions if instructions else f"Generate {n} instance(s) of the given model." + user_message = ( + instructions + if instructions + else f"Generate {n} instance(s) of the given model." + ) completion_client = Client( api_key=api_key, base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) if process == "single" or n == 1: - response_model = target if n == 1 else create_model("ResponseModel", items=(List[target], ...)) - + response_model = ( + target + if n == 1 + else create_model("ResponseModel", items=(List[target], ...)) + ) + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": user_message} + {"role": "user", "content": user_message}, ], model=model, response_model=response_model, @@ -75,7 +73,7 @@ def generate( max_retries=max_retries, temperature=temperature, ) - + return result if n == 1 else result.items else: # batch process results = [] @@ -84,13 +82,13 @@ def generate( batch_message = f"Generate {batch_n} instances of the given model." if results: batch_message += f"\nPreviously generated instances: {results[-3:]}\nEnsure these new instances are different." - + response_model = create_model("ResponseModel", items=(List[target], ...)) - + result = completion_client.completion( messages=[ {"role": "system", "content": system_message}, - {"role": "user", "content": batch_message} + {"role": "user", "content": batch_message}, ], model=model, response_model=response_model, @@ -98,14 +96,16 @@ def generate( max_retries=max_retries, temperature=temperature, ) - + results.extend(result.items) - + return results + if __name__ == "__main__": + class User(BaseModel): name: str age: int - print(generate(User, n=5, process="batch", batch_size=2, verbose=True)) \ No newline at end of file + print(generate(User, n=5, process="batch", batch_size=2, verbose=True)) diff --git a/zyx/resources/completions/base/system_prompt.py b/zyx/resources/completions/base/system_prompt.py index 2461127..1811c17 100644 --- a/zyx/resources/completions/base/system_prompt.py +++ b/zyx/resources/completions/base/system_prompt.py @@ -2,15 +2,13 @@ from pydantic import BaseModel, Field, create_model from ....lib.utils.logger import get_logger -from ....client import ( - Client, - InstructorMode -) +from ....client import Client, InstructorMode logger = get_logger("system_prompt") PROMPT_TYPES = Literal["costar", "tidd-ec"] + class Prompts: costar = """ ## CONTEXT ## @@ -105,17 +103,21 @@ class TiddECSystemPrompt(BaseModel): description="Provides concrete examples of desired outcomes or responses. 
This component is invaluable for guiding the LLM towards the expected format, style, or content of the response. Each item must be minimum 20 tokens.", ) + PROMPT_TYPES_MAPPING = {"costar": CostarSystemPrompt, "tidd-ec": TiddECSystemPrompt} + def get_system_prompt(type: PROMPT_TYPES = "costar") -> dict[str, str]: prompt_content = getattr(Prompts, type.replace("-", "_"), None) if prompt_content is None: raise ValueError(f"Invalid prompt type: {type}") return {"role": "system", "content": prompt_content} + def get_response_model(type: PROMPT_TYPES = "costar") -> BaseModel: return PROMPT_TYPES_MAPPING[type] + def system_prompt( instructions: Union[str, List[str]], type: PROMPT_TYPES = "costar", @@ -130,7 +132,7 @@ def system_prompt( temperature: Optional[float] = None, mode: InstructorMode = "markdown_json_mode", max_retries: int = 3, - max_tokens : Optional[int] = None, + max_tokens: Optional[int] = None, client: Optional[Literal["openai", "litellm"]] = None, response_format: Union[Literal["pydantic"], Literal["dict"], None] = None, verbose: bool = False, @@ -147,7 +149,7 @@ def system_prompt( base_url=base_url, organization=organization, provider=client, - verbose=verbose + verbose=verbose, ) response_model = get_response_model(type=type) @@ -157,7 +159,10 @@ def system_prompt( instructions = [instructions] if optimize: - instructions = [f"Optimize the following system prompt:\n\n{instr}" for instr in instructions] + instructions = [ + f"Optimize the following system prompt:\n\n{instr}" + for instr in instructions + ] results = [] @@ -165,10 +170,18 @@ def system_prompt( for instr in instructions: messages = [ system_prompt, - {"role": "user", "content": f"Generate a system prompt for the following instructions:\n\nINSTRUCTIONS:\n{instr}"} + { + "role": "user", + "content": f"Generate a system prompt for the following instructions:\n\nINSTRUCTIONS:\n{instr}", + }, ] if results: - messages.append({"role": "assistant", "content": f"Previously generated prompts:\n{results[-1]}"}) + messages.append( + { + "role": "assistant", + "content": f"Previously generated prompts:\n{results[-1]}", + } + ) result = completion_client.completion( messages=messages, @@ -182,18 +195,19 @@ def system_prompt( results.append(result) else: # batch process for i in range(0, len(instructions), batch_size): - batch = instructions[i:i+batch_size] - batch_message = "Generate system prompts for the following instructions:\n\n" + batch = instructions[i : i + batch_size] + batch_message = ( + "Generate system prompts for the following instructions:\n\n" + ) for idx, instr in enumerate(batch, 1): batch_message += f"{idx}. {instr}\n\n" - response_model_batch = create_model("ResponseModel", items=(List[response_model], ...)) - + response_model_batch = create_model( + "ResponseModel", items=(List[response_model], ...) 
+ ) + result = completion_client.completion( - messages=[ - system_prompt, - {"role": "user", "content": batch_message} - ], + messages=[system_prompt, {"role": "user", "content": batch_message}], model=model, response_model=response_model_batch, mode=mode, @@ -201,7 +215,7 @@ def system_prompt( max_tokens=max_tokens, temperature=temperature, ) - + results.extend(result.items) if response_format == "pydantic": @@ -219,19 +233,22 @@ def system_prompt( response_string.append(f"## {field.capitalize()} ##\n{formatted_value}\n\n") if response_format == "dict": - formatted_results.append({"role": "system", "content": "\n".join(response_string)}) + formatted_results.append( + {"role": "system", "content": "\n".join(response_string)} + ) else: formatted_results.append("\n".join(response_string)) return formatted_results if len(formatted_results) > 1 else formatted_results[0] + if __name__ == "__main__": # Example usage instructions = [ "Create a system prompt for a chatbot that helps users with programming questions.", - "Generate a system prompt for an AI assistant that provides travel recommendations." + "Generate a system prompt for an AI assistant that provides travel recommendations.", ] - + result = system_prompt( instructions=instructions, type="costar", @@ -239,10 +256,10 @@ def system_prompt( process="sequential", n=2, batch_size=2, - verbose=True + verbose=True, ) - + print("Generated System Prompts:") for idx, prompt in enumerate(result, 1): print(f"\nPrompt {idx}:") - print(prompt) \ No newline at end of file + print(prompt) diff --git a/zyx/resources/data/chunk.py b/zyx/resources/data/chunk.py index 24ad9df..9b2a2be 100644 --- a/zyx/resources/data/chunk.py +++ b/zyx/resources/data/chunk.py @@ -34,7 +34,7 @@ def chunk( max_token_chars: int: The maximum number of characters to use for chunking. Returns: - Union[List[str], List[List[str]]]: The chunked content. + Union[List[str], List[List[str]]]: The chunked content. 
""" try: tokenizer = tiktoken.encoding_for_model(model) diff --git a/zyx/resources/data/reader.py b/zyx/resources/data/reader.py index 825b461..8ff8e81 100644 --- a/zyx/resources/data/reader.py +++ b/zyx/resources/data/reader.py @@ -56,7 +56,7 @@ def read( paths = [_download_if_url(p) for p in path] else: paths = [_download_if_url(path)] - + paths = [Path(p) for p in paths] try: diff --git a/zyx/resources/ext/app.py b/zyx/resources/ext/app.py index af46410..e2474af 100644 --- a/zyx/resources/ext/app.py +++ b/zyx/resources/ext/app.py @@ -6,11 +6,7 @@ from ...lib.utils.logger import get_logger -from ...client import ( - Client, - completion, - InstructorMode -) +from ...client import Client, completion, InstructorMode logger = get_logger("app") @@ -140,6 +136,7 @@ "ivory", ] + class ZyxApp(App): CSS = """ Screen { @@ -333,8 +330,12 @@ def clear_messages(self): def save_params(self): self.params["model"] = self.query_one("#model_input").value or self.model - self.params["max_tokens"] = int(self.query_one("#max_tokens_input").value or 0) or None - self.params["temperature"] = float(self.query_one("#temperature_input").value or 0) or None + self.params["max_tokens"] = ( + int(self.query_one("#max_tokens_input").value or 0) or None + ) + self.params["temperature"] = ( + float(self.query_one("#temperature_input").value or 0) or None + ) self.params["instruction"] = self.query_one("#instruction_input").value # Update the class attributes @@ -413,6 +414,7 @@ def load_params(self): self.query_one("#temperature_input").value = str(self.params["temperature"]) self.query_one("#instruction_input").value = self.params["instruction"] + def terminal( messages: Union[str, list[dict]] = None, model: Optional[str] = "gpt-4o-mini", @@ -463,5 +465,6 @@ def terminal( except Exception as e: print(f"Error running ZyxApp: {e}") + if __name__ == "__main__": - terminal() \ No newline at end of file + terminal() diff --git a/zyx/resources/stores/memory.py b/zyx/resources/stores/memory.py index cec23c4..73b34e9 100644 --- a/zyx/resources/stores/memory.py +++ b/zyx/resources/stores/memory.py @@ -8,10 +8,7 @@ from ...lib.types.document import Document from ...lib.utils.logger import get_logger -from ...client import ( - completion, - InstructorMode -) +from ...client import completion, InstructorMode from ..data.chunk import chunk from ..completions.base.generate import generate @@ -53,7 +50,6 @@ def get_embedding_from_api(self, text: str) -> List[float]: class Memory: - """ Class for storing and retrieving data using Chroma. """ @@ -110,14 +106,18 @@ def _create_or_get_collection(self): embedding_fn = CustomEmbeddingFunction(api_key=self.embedding_api_key) if self.collection_name in self.client.list_collections(): logger.info(f"Collection '{self.collection_name}' already exists.") - return self.client.get_collection(self.collection_name, embedding_function=embedding_fn) + return self.client.get_collection( + self.collection_name, embedding_function=embedding_fn + ) else: logger.info(f"Creating collection '{self.collection_name}'.") - return self.client.create_collection(name=self.collection_name, embedding_function=embedding_fn) + return self.client.create_collection( + name=self.collection_name, embedding_function=embedding_fn + ) def _get_embedding(self, text: str) -> List[float]: """Generate embeddings for a given text using the custom embedding function. - + Args: text (str): The text to generate an embedding for. @@ -133,7 +133,7 @@ def add( metadata: Optional[dict] = None, ): """Add documents or data to Chroma. 
- + Args: data (Union[str, List[str], Document, List[Document]]): The data to add to Chroma. metadata (Optional[dict]): The metadata to add to the data. @@ -155,14 +155,14 @@ def add( # Chunk the content chunks = chunk(text, chunk_size=self.chunk_size, model=self.model) - + for chunk_text in chunks: embedding_vector = self._get_embedding(chunk_text) ids.append(str(uuid.uuid4())) embeddings.append(embedding_vector) texts.append(chunk_text) chunk_metadata = metadata.copy() if metadata else {} - chunk_metadata['chunk'] = True + chunk_metadata["chunk"] = True metadatas.append(chunk_metadata) except Exception as e: logger.error(f"Error processing item: {item}. Error: {e}") @@ -174,7 +174,9 @@ def add( self.collection.add( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts ) - logger.info(f"Successfully added {len(embeddings)} chunks to the collection.") + logger.info( + f"Successfully added {len(embeddings)} chunks to the collection." + ) except Exception as e: logger.error(f"Error adding points to collection: {e}") else: @@ -182,7 +184,7 @@ def add( def search(self, query: str, top_k: int = 5) -> SearchResponse: """Search in Chroma collection. - + Args: query (str): The query to search for. top_k (int): The number of results to return. @@ -192,7 +194,9 @@ def search(self, query: str, top_k: int = 5) -> SearchResponse: """ try: query_embedding = self._get_embedding(query) - search_results = self.collection.query(query_embeddings=[query_embedding], n_results=top_k) + search_results = self.collection.query( + query_embeddings=[query_embedding], n_results=top_k + ) nodes = [] for i in range(len(search_results["ids"][0])): # Note the [0] here @@ -200,7 +204,9 @@ def search(self, query: str, top_k: int = 5) -> SearchResponse: id=search_results["ids"][0][i], text=search_results["documents"][0][i], embedding=query_embedding, - metadata=search_results["metadatas"][0][i] if search_results["metadatas"] else {} + metadata=search_results["metadatas"][0][i] + if search_results["metadatas"] + else {}, ) nodes.append(node) return SearchResponse(query=query, results=nodes) @@ -210,13 +216,14 @@ def search(self, query: str, top_k: int = 5) -> SearchResponse: def _summarize_results(self, results: List[ChromaNode]) -> str: """Summarize the search results. - + Args: results (List[ChromaNode]): The search results. Returns: str: The summary of the search results. """ + class SummaryModel(BaseModel): summary: str @@ -227,7 +234,7 @@ class SummaryModel(BaseModel): SummaryModel, instructions="Provide a concise summary of the following text, focusing on the most important information:", model=self.model, - n=1 + n=1, ) return summary.summary @@ -251,7 +258,7 @@ def completion( verbose: Optional[bool] = False, ): """Perform completion with context from Chroma. - + Args: messages (Union[str, List[dict]]): The messages to use for the completion. model (Optional[str]): The model to use for the completion. 
@@ -274,7 +281,10 @@ def completion(
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
         elif isinstance(messages, list):
-            messages = [{"role": "user", "content": m} if isinstance(m, str) else m for m in messages]
+            messages = [
+                {"role": "user", "content": m} if isinstance(m, str) else m
+                for m in messages
+            ]
 
         query = messages[-1].get("content", "") if messages else ""
 
@@ -295,7 +305,9 @@ def completion(
         else:
             for message in messages:
                 if message.get("role", "") == "system":
-                    message["content"] += f"\nAdditional context: {summarized_results}"
+                    message["content"] += (
+                        f"\nAdditional context: {summarized_results}"
+                    )
 
         try:
             result = completion(
@@ -326,14 +338,22 @@ def completion(
 if __name__ == "__main__":
     try:
         # Initialize the Store
-        store = Memory(collection_name="test_collection", embedding_api_key="your-api-key")
-
+        store = Memory(
+            collection_name="test_collection", embedding_api_key="your-api-key"
+        )
+
         # Test adding single string
         store.add("This is a single string test.")
         print("Added single string.")
 
         # Test adding list of strings
-        store.add(["Multiple string test 1", "Multiple string test 2", "Multiple string test 3"])
+        store.add(
+            [
+                "Multiple string test 1",
+                "Multiple string test 2",
+                "Multiple string test 3",
+            ]
+        )
         print("Added multiple strings.")
 
         # Test adding Document
@@ -372,10 +392,10 @@ def completion(
             messages=[{"role": "user", "content": "Summarize the documents."}],
             model="gpt-3.5-turbo",
             temperature=0.7,
-            max_tokens=150
+            max_tokens=150,
        )
         print("\nCustom completion result:")
         print(custom_completion)
 
     except Exception as e:
-        logger.error(f"Error in main execution: {e}")
\ No newline at end of file
+        logger.error(f"Error in main execution: {e}")
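
Editor's note (not part of the patch): the guardrails path added to judge() in zyx/resources/completions/agents/judge.py is the one new behavior in this series that the module's __main__ block does not exercise. The sketch below is illustrative only; it assumes the judge() parameters visible in the hunks above (responses, process, guardrails, regenerate, verbose) and the top-level export of judge from zyx, and the guardrail strings themselves are hypothetical.

# Illustrative sketch only -- assumes judge() as defined in
# zyx/resources/completions/agents/judge.py and exported from zyx/__init__.py.
from zyx import judge

result = judge(
    "Explain the concept of quantum entanglement.",
    responses=[
        "Quantum entanglement is when particles are really close to each other "
        "and move in the same way."
    ],
    process="fact_check",
    guardrails=[
        "Do not speculate beyond established physics.",  # hypothetical guardrail text
        "Keep the explanation under 200 words.",
    ],
    regenerate=True,  # regenerate the answer if it fails the fact check or a guardrail
    verbose=True,
)
# Depending on whether regeneration ran, the result is a FactCheckResult
# (is_accurate, explanation, confidence) or a RegeneratedResponse (response).
print(result.response if hasattr(result, "response") else result.explanation)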