From b5ecc1a7aa37c59b44703afe74171b783a5f1102 Mon Sep 17 00:00:00 2001 From: NotBioWaste905 Date: Tue, 26 Nov 2024 16:10:22 +0300 Subject: [PATCH] Renamed BaseFilter into BaseHistoryFilter, added API reference --- chatsky/llm/__init__.py | 2 +- chatsky/llm/filters.py | 8 ++++---- chatsky/llm/utils.py | 4 ++-- chatsky/responses/llm.py | 4 ++-- docs/source/user_guides/llm_integration.rst | 2 +- tutorials/llm/3_filtering_history.py | 7 ++++--- tutorials/llm/4_structured_output.py | 5 +++-- tutorials/llm/5_llm_slots.py | 1 - 8 files changed, 17 insertions(+), 16 deletions(-) diff --git a/chatsky/llm/__init__.py b/chatsky/llm/__init__.py index 0c2a948fa..e31b49187 100644 --- a/chatsky/llm/__init__.py +++ b/chatsky/llm/__init__.py @@ -1,3 +1,3 @@ -from chatsky.llm.filters import BaseFilter, FromModel, IsImportant +from chatsky.llm.filters import BaseHistoryFilter, FromModel, IsImportant from chatsky.llm.methods import BaseMethod, LogProb, Contains from chatsky.llm.llm_api import LLM_API diff --git a/chatsky/llm/filters.py b/chatsky/llm/filters.py index 1a0919a6c..c47107c6e 100644 --- a/chatsky/llm/filters.py +++ b/chatsky/llm/filters.py @@ -16,7 +16,7 @@ logger = Logger(name=__name__) -class BaseFilter(BaseModel, abc.ABC): +class BaseHistoryFilter(BaseModel, abc.ABC): """ Base class for all message history filters. """ @@ -50,7 +50,7 @@ def __call__(self, ctx: Context, request: Message, response: Message, model_name return [] -class MessageFilter(BaseFilter): +class MessageFilter(BaseHistoryFilter): @abc.abstractmethod def call(self, ctx, message, model_name): raise NotImplemented @@ -59,7 +59,7 @@ def __call__(self, ctx, request, response, model_name): return self.call(ctx, request, model_name) + self.call(ctx, response, model_name) -class IsImportant(BaseFilter): +class IsImportant(BaseHistoryFilter): """ Filter that checks if the "important" field in a Message.misc is True. """ @@ -72,7 +72,7 @@ def call(self, ctx: Context, request: Message, response: Message, model_name: st return False -class FromModel(BaseFilter): +class FromModel(BaseHistoryFilter): """ Filter that checks if the message was sent by the model. """ diff --git a/chatsky/llm/utils.py b/chatsky/llm/utils.py index 8f4885d06..f6266a8c9 100644 --- a/chatsky/llm/utils.py +++ b/chatsky/llm/utils.py @@ -10,7 +10,7 @@ from chatsky.core.context import Context from chatsky.core.message import Image, Message from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available -from chatsky.llm.filters import BaseFilter +from chatsky.llm.filters import BaseHistoryFilter async def message_to_langchain( @@ -45,7 +45,7 @@ async def message_to_langchain( async def context_to_history( - ctx: Context, length: int, filter_func: BaseFilter, model_name: str, max_size: int + ctx: Context, length: int, filter_func: BaseHistoryFilter, model_name: str, max_size: int ) -> list[HumanMessage | AIMessage | SystemMessage]: """ Convert context to list of langchain messages. 
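A minimal sketch of what a user-defined filter looks like under the renamed base class. The __call__ signature follows the docs hunk later in this patch; the class name FromUser and the text check are illustrative assumptions, not part of the change.

    from chatsky.core.context import Context
    from chatsky.core.message import Message
    from chatsky.llm.filters import BaseHistoryFilter


    class FromUser(BaseHistoryFilter):
        """Illustrative filter: keep only turns whose request carries text."""

        def __call__(
            self, ctx: Context, request: Message, response: Message, model_name: str
        ) -> bool:
            # Signature mirrors the one documented in llm_integration.rst below.
            return request is not None and bool(request.text)
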
diff --git a/chatsky/responses/llm.py b/chatsky/responses/llm.py index 8943dcdd4..0f750a1d1 100644 --- a/chatsky/responses/llm.py +++ b/chatsky/responses/llm.py @@ -13,7 +13,7 @@ from chatsky.core.context import Context from chatsky.llm.utils import message_to_langchain, context_to_history from chatsky.llm._langchain_imports import SystemMessage, check_langchain_available -from chatsky.llm.filters import BaseFilter +from chatsky.llm.filters import BaseHistoryFilter from chatsky.core.script_function import BaseResponse, AnyResponse @@ -32,7 +32,7 @@ class LLMResponse(BaseResponse): model_name: str prompt: AnyResponse = Field(default="", validate_default=True) history: int = 5 - filter_func: BaseFilter = BaseFilter() + filter_func: BaseHistoryFilter = BaseHistoryFilter() message_schema: Union[None, Type[Message], Type[BaseModel]] = None max_size: int = 1000 diff --git a/docs/source/user_guides/llm_integration.rst b/docs/source/user_guides/llm_integration.rst index 7341bb2e5..0cf86f5d6 100644 --- a/docs/source/user_guides/llm_integration.rst +++ b/docs/source/user_guides/llm_integration.rst @@ -145,5 +145,5 @@ Another way of dealing with unwanted messages is by using filtering functions. from chatsky.llm import IsImportant RESPONSE: LLMResponse(model_name="model_name_1", history=15, filter_func=IsImportant) -These functions should be classes inheriting from ``BaseFilter``, having a ``__call__`` function with the following signature: +These functions should be classes inheriting from ``BaseHistoryFilter``, having a ``__call__`` function with the following signature: ``def __call__(self, ctx: Context, request: Message, response: Message, model_name: str) -> bool`` diff --git a/tutorials/llm/3_filtering_history.py b/tutorials/llm/3_filtering_history.py index 99d17adc4..baff117f4 100644 --- a/tutorials/llm/3_filtering_history.py +++ b/tutorials/llm/3_filtering_history.py @@ -25,11 +25,10 @@ from chatsky.utils.testing import is_interactive_mode from chatsky.llm import LLM_API from chatsky.responses.llm import LLMResponse -from chatsky.llm.filters import BaseFilter +from chatsky.llm.filters import BaseHistoryFilter from chatsky.core.context import Context - # %% model = LLM_API( ChatOllama(model="phi3:instruct", temperature=0), @@ -41,11 +40,13 @@ """ In this example we will use very simple filtering function to retrieve only the important messages. +If you want to learn more about filters see +[API ref](%doclink(api,llm.filters,BaseHistoryFilter)). """ # %% -class FilterImportant(BaseFilter): +class FilterImportant(BaseHistoryFilter): def __call__( self, ctx: Context = None, diff --git a/tutorials/llm/4_structured_output.py b/tutorials/llm/4_structured_output.py index c2afc241d..b05445bf4 100644 --- a/tutorials/llm/4_structured_output.py +++ b/tutorials/llm/4_structured_output.py @@ -29,14 +29,15 @@ from langchain_core.pydantic_v1 import BaseModel, Field - # %% [markdown] """ In this tutorial we will define two models. 
""" # %% assistant_model = LLM_API(ChatOllama(model="llama3.2:1b", temperature=0)) -movie_model = LLM_API(ChatOllama(model="kuqoi/qwen2-tools:latest", temperature=0)) +movie_model = LLM_API( + ChatOllama(model="kuqoi/qwen2-tools:latest", temperature=0) +) # %% [markdown] """ diff --git a/tutorials/llm/5_llm_slots.py b/tutorials/llm/5_llm_slots.py index 190459401..9d0020a98 100644 --- a/tutorials/llm/5_llm_slots.py +++ b/tutorials/llm/5_llm_slots.py @@ -33,7 +33,6 @@ from chatsky.slots.llm import LLMSlot, LLMGroupSlot - # %% [markdown] """ In this example we define LLM Group Slot with two LLM Slots in it.