Skip to content

Commit

Permalink
context storage class split
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 22, 2024
1 parent 9e7cf47 commit c34f8e7
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 36 deletions.
26 changes: 13 additions & 13 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from chatsky.core.message import Message
from chatsky.slots.slots import SlotManager
from chatsky.core.node_label import AbsoluteNodeLabel
from chatsky.core.ctx_dict import ContextDict
from chatsky.core.ctx_dict import ContextDict, LabelContextDict, MessageContextDict

if TYPE_CHECKING:
from chatsky.core.service import ComponentExecutionState
Expand Down Expand Up @@ -102,9 +102,9 @@ class Context(BaseModel):
It is set (and managed) by :py:class:`~chatsky.context_storages.DBContextStorage`.
"""
current_turn_id: int = Field(default=0)
labels: ContextDict[int, AbsoluteNodeLabel] = Field(default_factory=lambda: ContextDict.empty(AbsoluteNodeLabel))
requests: ContextDict[int, Message] = Field(default_factory=lambda: ContextDict.empty(Message))
responses: ContextDict[int, Message] = Field(default_factory=lambda: ContextDict.empty(Message))
labels: LabelContextDict = Field(default_factory=LabelContextDict)
requests: MessageContextDict = Field(default_factory=MessageContextDict)
responses: MessageContextDict = Field(default_factory=MessageContextDict)
"""
`turns` stores the history of all passed `labels`, `requests`, and `responses`.
Expand Down Expand Up @@ -139,9 +139,9 @@ async def connected(
uid = str(uuid4())
logger.debug(f"Disconnected context created with uid: {uid}")
instance = cls(id=uid)
instance.requests = await ContextDict.new(storage, uid, storage._requests_field_name, Message)
instance.responses = await ContextDict.new(storage, uid, storage._responses_field_name, Message)
instance.labels = await ContextDict.new(storage, uid, storage._labels_field_name, AbsoluteNodeLabel)
instance.requests = await MessageContextDict.new(storage, uid, storage._requests_field_name)
instance.responses = await MessageContextDict.new(storage, uid, storage._responses_field_name)
instance.labels = await LabelContextDict.new(storage, uid, storage._labels_field_name)
await instance.labels.update({0: start_label})
instance._storage = storage
return instance
Expand All @@ -152,9 +152,9 @@ async def connected(
logger.debug(f"Connected context created with uid: {id}")
main, labels, requests, responses = await gather(
storage.load_main_info(id),
ContextDict.connected(storage, id, storage._labels_field_name, AbsoluteNodeLabel),
ContextDict.connected(storage, id, storage._requests_field_name, Message),
ContextDict.connected(storage, id, storage._responses_field_name, Message),
LabelContextDict.connected(storage, id, storage._labels_field_name),
MessageContextDict.connected(storage, id, storage._requests_field_name),
MessageContextDict.connected(storage, id, storage._responses_field_name),
)
if main is None:
crt_at = upd_at = time_ns()
Expand Down Expand Up @@ -250,17 +250,17 @@ def _validate_model(value: Any, handler: Callable[[Any], "Context"], _) -> "Cont
labels_obj = value.get("labels", dict())
if isinstance(labels_obj, Dict):
labels_obj = TypeAdapter(Dict[int, AbsoluteNodeLabel]).validate_python(labels_obj)
instance.labels = ContextDict.model_validate(labels_obj)
instance.labels = LabelContextDict.model_validate(labels_obj)
instance.labels._ctx_id = instance.id
requests_obj = value.get("requests", dict())
if isinstance(requests_obj, Dict):
requests_obj = TypeAdapter(Dict[int, Message]).validate_python(requests_obj)
instance.requests = ContextDict.model_validate(requests_obj)
instance.requests = MessageContextDict.model_validate(requests_obj)
instance.requests._ctx_id = instance.id
responses_obj = value.get("responses", dict())
if isinstance(responses_obj, Dict):
responses_obj = TypeAdapter(Dict[int, Message]).validate_python(responses_obj)
instance.responses = ContextDict.model_validate(responses_obj)
instance.responses = MessageContextDict.model_validate(responses_obj)
instance.responses._ctx_id = instance.id
return instance
else:
Expand Down
40 changes: 25 additions & 15 deletions chatsky/core/ctx_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from abc import abstractmethod
from asyncio import gather
from hashlib import sha256
import logging
Expand All @@ -22,6 +23,8 @@

from pydantic import BaseModel, PrivateAttr, TypeAdapter, model_serializer, model_validator

from chatsky.core.message import Message
from chatsky.core.node_label import AbsoluteNodeLabel
from chatsky.utils.logging import collapse_num_list

if TYPE_CHECKING:
Expand All @@ -47,39 +50,34 @@ class ContextDict(BaseModel, Generic[K, V]):
_storage: Optional[DBContextStorage] = PrivateAttr(None)
_ctx_id: str = PrivateAttr(default_factory=str)
_field_name: str = PrivateAttr(default_factory=str)
_value_type: Optional[TypeAdapter[Type[V]]] = PrivateAttr(None)

@classmethod
def empty(cls, value_type: Type[V]) -> "ContextDict":
instance = cls()
instance._value_type = TypeAdapter(value_type)
return instance
@property
@abstractmethod
def _value_type(self) -> TypeAdapter[Type[V]]:
    """
    Type adapter for the values stored in this dictionary.

    Subclasses are expected to override this property; the base
    implementation only raises.

    :raises NotImplementedError: Always, unless overridden.
    """
    raise NotImplementedError()

@classmethod
async def new(cls, storage: DBContextStorage, id: str, field: str, value_type: Type[V]) -> "ContextDict":
instance = cls.empty(value_type)
async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict":
instance = cls()
logger.debug(f"Disconnected context dict created for id {id} and field name: {field}")
instance._ctx_id = id
instance._field_name = field
instance._storage = storage
return instance

@classmethod
async def connected(cls, storage: DBContextStorage, id: str, field: str, value_type: Type[V]) -> "ContextDict":
val_adapter = TypeAdapter(value_type)
async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict":
logger.debug(f"Connected context dict created for {id}, {field}")
keys, items = await gather(storage.load_field_keys(id, field), storage.load_field_latest(id, field))
val_key_items = [(k, v) for k, v in items if v is not None]
hashes = {k: get_hash(v) for k, v in val_key_items}
objected = {k: val_adapter.validate_json(v) for k, v in val_key_items}
instance = cls.model_validate(objected)
logger.debug(f"Context dict for {id}, {field} loaded: {collapse_num_list(keys)}")
instance = cls()
instance._storage = storage
instance._ctx_id = id
instance._field_name = field
instance._value_type = val_adapter
instance._keys = set(keys)
instance._hashes = hashes
instance._items = {k: instance._value_type.validate_json(v) for k, v in val_key_items}
instance._hashes = {k: get_hash(v) for k, v in val_key_items}
return instance

async def _load_items(self, keys: List[K]) -> Dict[K, V]:
Expand Down Expand Up @@ -277,3 +275,15 @@ async def store(self) -> None:
self._hashes[k] = get_hash(self._value_type.dump_json(v))
else:
raise RuntimeError(f"{type(self).__name__} is not attached to any context storage!")


# Shared adapter for `LabelContextDict`, created on first use (see `_value_type`).
_LABEL_ADAPTER: Optional[TypeAdapter] = None


class LabelContextDict(ContextDict[int, AbsoluteNodeLabel]):
    """
    Context dictionary whose values are :py:class:`~chatsky.core.node_label.AbsoluteNodeLabel`.
    """

    @property
    def _value_type(self) -> TypeAdapter[Type[AbsoluteNodeLabel]]:
        """
        Type adapter used to validate and dump the stored labels.

        The adapter is built once and then reused: this property is read
        once per item while loading and storing, so constructing a new
        ``TypeAdapter`` on every access would repeat the same setup work.
        It is created on first access rather than at import time.

        :return: The shared adapter for ``AbsoluteNodeLabel``.
        """
        global _LABEL_ADAPTER
        if _LABEL_ADAPTER is None:
            _LABEL_ADAPTER = TypeAdapter(AbsoluteNodeLabel)
        return _LABEL_ADAPTER


# Shared adapter for `MessageContextDict`, created on first use (see `_value_type`).
_MESSAGE_ADAPTER: Optional[TypeAdapter] = None


class MessageContextDict(ContextDict[int, Message]):
    """
    Context dictionary whose values are :py:class:`~chatsky.core.message.Message`.
    """

    @property
    def _value_type(self) -> TypeAdapter[Type[Message]]:
        """
        Type adapter used to validate and dump the stored messages.

        The adapter is built once and then reused: this property is read
        once per item while loading and storing, so constructing a new
        ``TypeAdapter`` on every access would repeat the same setup work.
        It is created on first access rather than at import time.

        :return: The shared adapter for ``Message``.
        """
        global _MESSAGE_ADAPTER
        if _MESSAGE_ADAPTER is None:
            _MESSAGE_ADAPTER = TypeAdapter(Message)
        return _MESSAGE_ADAPTER
12 changes: 6 additions & 6 deletions tests/core/test_context_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@

from chatsky.context_storages import MemoryContextStorage
from chatsky.core.message import Message
from chatsky.core.ctx_dict import ContextDict
from chatsky.core.ctx_dict import ContextDict, MessageContextDict


class TestContextDict:
@pytest.fixture(scope="function")
async def empty_dict(self) -> ContextDict:
# Empty (disconnected) context dictionary
return ContextDict.empty(Message)
return MessageContextDict()

@pytest.fixture(scope="function")
async def attached_dict(self) -> ContextDict:
# Attached, but not backed by any data context dictionary
storage = MemoryContextStorage()
return await ContextDict.new(storage, "ID", storage._requests_field_name, Message)
return await MessageContextDict.new(storage, "ID", storage._requests_field_name)

@pytest.fixture(scope="function")
async def prefilled_dict(self) -> ContextDict:
Expand All @@ -28,7 +28,7 @@ async def prefilled_dict(self) -> ContextDict:
(2, Message("text 2", misc={"1": 0, "2": 8}).model_dump_json().encode()),
]
await storage.update_field_items(ctx_id, storage._requests_field_name, requests)
return await ContextDict.connected(storage, ctx_id, storage._requests_field_name, Message)
return await MessageContextDict.connected(storage, ctx_id, storage._requests_field_name)

async def test_creation(
self, empty_dict: ContextDict, attached_dict: ContextDict, prefilled_dict: ContextDict
Expand Down Expand Up @@ -122,11 +122,11 @@ async def test_other_methods(self, prefilled_dict: ContextDict) -> None:

async def test_eq_validate(self, empty_dict: ContextDict) -> None:
# Checking empty dict validation
assert empty_dict == ContextDict.model_validate(dict())
assert empty_dict == MessageContextDict.model_validate(dict())
# Checking non-empty dict validation
empty_dict[0] = Message("msg")
empty_dict._added = set()
assert empty_dict == ContextDict.model_validate({0: Message("msg")})
assert empty_dict == MessageContextDict.model_validate({0: Message("msg")})

async def test_serialize_store(
self, empty_dict: ContextDict, attached_dict: ContextDict, prefilled_dict: ContextDict
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ async def test_get_context(context_storage: JSONContextStorage):
context = await get_context(context_storage, 2, (1, 2), (2, 3))
copy_ctx = await Context.connected(context_storage, ("flow", "node"))
await copy_ctx.labels.update({0: ("flow_0", "node_0"), 1: ("flow_1", "node_1")})
await copy_ctx.requests.update({0: Message(misc={"0": "zv"}), 1: Message(misc={"0": "sh"})})
await copy_ctx.responses.update({0: Message(misc={"0": ">e"}), 1: Message(misc={"0": "3 "})})
await copy_ctx.requests.update({0: Message(misc={"0": ">e"}), 1: Message(misc={"0": "zv"})})
await copy_ctx.responses.update({0: Message(misc={"0": "3 "}), 1: Message(misc={"0": "sh"})})
copy_ctx.misc.update({"0": " d]", "1": " (b"})
assert context.model_dump(exclude={"id", "current_turn_id"}) == copy_ctx.model_dump(exclude={"id", "current_turn_id"})

Expand Down

0 comments on commit c34f8e7

Please sign in to comment.