Skip to content

Commit

Permalink
Renamed BaseFilter to BaseHistoryFilter, added API reference
Browse files Browse the repository at this point in the history
  • Loading branch information
NotBioWaste905 committed Nov 26, 2024
1 parent 248d77f commit b5ecc1a
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion chatsky/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from chatsky.llm.filters import BaseFilter, FromModel, IsImportant
from chatsky.llm.filters import BaseHistoryFilter, FromModel, IsImportant
from chatsky.llm.methods import BaseMethod, LogProb, Contains
from chatsky.llm.llm_api import LLM_API
8 changes: 4 additions & 4 deletions chatsky/llm/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger = Logger(name=__name__)


class BaseFilter(BaseModel, abc.ABC):
class BaseHistoryFilter(BaseModel, abc.ABC):
"""
Base class for all message history filters.
"""
Expand Down Expand Up @@ -50,7 +50,7 @@ def __call__(self, ctx: Context, request: Message, response: Message, model_name
return []


class MessageFilter(BaseFilter):
class MessageFilter(BaseHistoryFilter):
@abc.abstractmethod
def call(self, ctx, message, model_name):
raise NotImplemented
Expand All @@ -59,7 +59,7 @@ def __call__(self, ctx, request, response, model_name):
return self.call(ctx, request, model_name) + self.call(ctx, response, model_name)


class IsImportant(BaseFilter):
class IsImportant(BaseHistoryFilter):
"""
Filter that checks if the "important" field in a Message.misc is True.
"""
Expand All @@ -72,7 +72,7 @@ def call(self, ctx: Context, request: Message, response: Message, model_name: st
return False


class FromModel(BaseFilter):
class FromModel(BaseHistoryFilter):
"""
Filter that checks if the message was sent by the model.
"""
Expand Down
4 changes: 2 additions & 2 deletions chatsky/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from chatsky.core.context import Context
from chatsky.core.message import Image, Message
from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
from chatsky.llm.filters import BaseFilter
from chatsky.llm.filters import BaseHistoryFilter


async def message_to_langchain(
Expand Down Expand Up @@ -45,7 +45,7 @@ async def message_to_langchain(


async def context_to_history(
ctx: Context, length: int, filter_func: BaseFilter, model_name: str, max_size: int
ctx: Context, length: int, filter_func: BaseHistoryFilter, model_name: str, max_size: int
) -> list[HumanMessage | AIMessage | SystemMessage]:
"""
Convert context to list of langchain messages.
Expand Down
4 changes: 2 additions & 2 deletions chatsky/responses/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from chatsky.core.context import Context
from chatsky.llm.utils import message_to_langchain, context_to_history
from chatsky.llm._langchain_imports import SystemMessage, check_langchain_available
from chatsky.llm.filters import BaseFilter
from chatsky.llm.filters import BaseHistoryFilter
from chatsky.core.script_function import BaseResponse, AnyResponse


Expand All @@ -32,7 +32,7 @@ class LLMResponse(BaseResponse):
model_name: str
prompt: AnyResponse = Field(default="", validate_default=True)
history: int = 5
filter_func: BaseFilter = BaseFilter()
filter_func: BaseHistoryFilter = BaseHistoryFilter()
message_schema: Union[None, Type[Message], Type[BaseModel]] = None
max_size: int = 1000

Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guides/llm_integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,5 @@ Another way of dealing with unwanted messages is by using filtering functions.
from chatsky.llm import IsImportant
RESPONSE: LLMResponse(model_name="model_name_1", history=15, filter_func=IsImportant)
These functions should be classes inheriting from ``BaseFilter``, having a ``__call__`` function with the following signature:
These functions should be classes inheriting from ``BaseHistoryFilter``, having a ``__call__`` function with the following signature:
``def __call__(self, ctx: Context, request: Message, response: Message, model_name: str) -> bool``
7 changes: 4 additions & 3 deletions tutorials/llm/3_filtering_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@
from chatsky.utils.testing import is_interactive_mode
from chatsky.llm import LLM_API
from chatsky.responses.llm import LLMResponse
from chatsky.llm.filters import BaseFilter
from chatsky.llm.filters import BaseHistoryFilter
from chatsky.core.context import Context



# %%
model = LLM_API(
ChatOllama(model="phi3:instruct", temperature=0),
Expand All @@ -41,11 +40,13 @@
"""
In this example we will use very simple filtering function to
retrieve only the important messages.
If you want to learn more about filters see
[API ref](%doclink(api,llm.filters,BaseHistoryFilter)).
"""


# %%
class FilterImportant(BaseFilter):
class FilterImportant(BaseHistoryFilter):
def __call__(
self,
ctx: Context = None,
Expand Down
5 changes: 3 additions & 2 deletions tutorials/llm/4_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
from langchain_core.pydantic_v1 import BaseModel, Field



# %% [markdown]
"""
In this tutorial we will define two models.
"""
# %%
assistant_model = LLM_API(ChatOllama(model="llama3.2:1b", temperature=0))
movie_model = LLM_API(ChatOllama(model="kuqoi/qwen2-tools:latest", temperature=0))
movie_model = LLM_API(
ChatOllama(model="kuqoi/qwen2-tools:latest", temperature=0)
)

# %% [markdown]
"""
Expand Down
1 change: 0 additions & 1 deletion tutorials/llm/5_llm_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from chatsky.slots.llm import LLMSlot, LLMGroupSlot



# %% [markdown]
"""
In this example we define LLM Group Slot with two LLM Slots in it.
Expand Down

0 comments on commit b5ecc1a

Please sign in to comment.