Skip to content

Commit

Permalink
Fixes for message model syncing and more double parse_set calls.
Browse files Browse the repository at this point in the history
Version to 0.2.2
  • Loading branch information
monoxgas committed Apr 25, 2024
1 parent 4a4cb32 commit 87bbdf9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rigging"
version = "0.2.1"
version = "0.2.2"
description = "LLM Interaction Framework"
authors = ["Nick Landers <monoxgas@gmail.com>"]
license = "MIT"
Expand Down
32 changes: 29 additions & 3 deletions rigging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,42 @@ def _add_part(self, part: ParsedMessagePart) -> None:
raise ValueError("Incoming part overlaps with an existing part")
self.parts.append(part)

# Looks more complicated than it is. We just want to clean all the models
# in the message content by re-serializing them. As we do so, we'll need
# to watch for the total size of our message shifting and update the slices
# of the following parts accordingly. In other words, as A expands, B which
# follows will have a new start slice and end slice.
#
# TODO: We should probably just re-trigger parsing for everything
def _sync_parts(self) -> None:
shift = 0
for part in self.parts:
existing = self._content[part.slice_]

# Adjust for any previous shifts
part.slice_ = slice(part.slice_.start + shift, part.slice_.stop + shift)

# Check if the content has changed
xml_content = part.model.to_pretty_xml()
if xml_content == existing:
continue

# Otherwise update content, add to shift, and update this slice
old_length = part.slice_.stop - part.slice_.start
new_length = len(xml_content)

self._content = self._content[: part.slice_.start] + xml_content + self._content[part.slice_.stop :]
part.slice_ = slice(part.slice_.start, part.slice_.start + len(xml_content))
part.slice_ = slice(part.slice_.start, part.slice_.start + new_length)

shift += new_length - old_length

@computed_field # type: ignore[misc]
@property
def content(self) -> str:
self._sync_parts()
# We used to sync the models and content each time it was accessed,
# hence the getter. Now we just return the stored content.
# I'll leave it as is for now in case we want to add any
# logic here in the future.
return self._content

@content.setter
Expand Down Expand Up @@ -152,7 +178,7 @@ def try_parse_many(self, types: t.Sequence[type[ModelT]], fail_on_missing: bool
except MissingModelError as e:
if fail_on_missing:
raise e

self._sync_parts()
return parsed

@classmethod
Expand Down
9 changes: 7 additions & 2 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,13 @@ def test_double_parse() -> None:


def test_double_parse_set() -> None:
msg = Message("user", "<person name='John'>30</person><person name='Jane'>25</person>")
msg = Message(
"user",
"Some test content <anothertag><person name='John'>30</person> More mixed content <person name='omad'>90</person><person name='Jane'>25</person>",
)
existing_len = len(msg.content)
msg.parse_set(Person)
msg.parse_set(Person)

assert len(msg.parts) == 2
assert len(msg.content) != existing_len
assert len(msg.parts) == 3

0 comments on commit 87bbdf9

Please sign in to comment.