From b8d27bce876bc35834bb4f971af6f066eca1ac47 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Sun, 17 Mar 2024 18:09:42 -0400 Subject: [PATCH] 0.1.7 - Improve nested models and attr/element support --- README.md | 26 ++++++++++++++++++++++++++ pyproject.toml | 2 +- rigging/__init__.py | 4 ++-- rigging/model.py | 8 ++++++-- test.py | 33 --------------------------------- tests/test_xml_parsing.py | 34 +++++++++++++++++++++++++++++++++- 6 files changed, 68 insertions(+), 39 deletions(-) delete mode 100644 test.py diff --git a/README.md b/README.md index e658e1a..5656903 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,32 @@ print(f"{chat.last!r}") # Message(role='assistant', parts=[], content='new content') ``` +### Complex Models + +```python +import rigging as rg + +class Inner(rg.Model): + type: str = rg.attr() + content: str + +class Outer(rg.Model): + name: str = rg.attr() + inners: list[Inner] = rg.element() + +outer = Outer(name="foo", inners=[ + Inner(type="cat", content="meow"), + Inner(type="dog", content="bark") +]) + +print(outer.to_pretty_xml()) + +# +# meow +# bark +# +``` + ### Tools ```python diff --git a/pyproject.toml b/pyproject.toml index 34a4097..2b2c1e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "rigging" -version = "0.1.6" +version = "0.1.7" description = "LLM Interaction Framework" authors = ["Nick Landers "] license = "MIT" diff --git a/rigging/__init__.py b/rigging/__init__.py index 445b18c..ce29353 100644 --- a/rigging/__init__.py +++ b/rigging/__init__.py @@ -1,9 +1,9 @@ from rigging.generator import get_generator from rigging.message import Message, MessageDict, Messages -from rigging.model import Model +from rigging.model import Model, attr, element from rigging.tool import Tool -__all__ = ["get_generator", "Message", "MessageDict", "Messages", "Tool", "Model"] +__all__ = ["get_generator", "Message", "MessageDict", "Messages", "Tool", "Model", "attr", "element"] from loguru import logger diff --git a/rigging/model.py b/rigging/model.py index a31179b..05de4bb 100644 --- a/rigging/model.py +++ b/rigging/model.py @@ -5,6 +5,8 @@ from pydantic import ValidationError, field_validator from pydantic.alias_generators import to_snake from pydantic_xml import BaseXmlModel +from pydantic_xml import attr as attr +from pydantic_xml import element as element from pydantic_xml.element import SearchMode # type: ignore [attr-defined] from pydantic_xml.typedefs import NsMap @@ -21,6 +23,8 @@ # content to be escaped. We should probably just write something # custom for our use case that supports JSON, YAML, and XML +BASIC_TYPES = [int, str, float, bool] + class XmlTagDescriptor: def __get__(self, _: t.Any, owner: t.Any) -> str: @@ -68,7 +72,7 @@ def to_pretty_xml(self) -> str: @classmethod def is_simple(cls) -> bool: field_values = list(cls.model_fields.values()) - return len(field_values) == 1 + return len(field_values) == 1 and field_values[0].annotation in BASIC_TYPES @classmethod def xml_start_tag(cls) -> str: @@ -100,7 +104,7 @@ def xml_example(cls) -> str: @classmethod def extract_xml(cls, content: str) -> tuple[ModelGeneric, str]: - pattern = r"(<([\w-]+)>((.*?)))" + pattern = r"(<([\w-]+).*?>((.*?)))" matches = re.findall(pattern, content, flags=re.DOTALL) matches_with_tag = [m for m in matches if m[1] == cls.__xml_tag__] diff --git a/test.py b/test.py deleted file mode 100644 index 2e0403a..0000000 --- a/test.py +++ /dev/null @@ -1,33 +0,0 @@ -import rigging as rg - - -class Reasoning(rg.Model): - content: str - - -generator = rg.get_generator("claude-2.1") - -meaning = generator.chat( - [ - { - "role": "user", - "content": "What is the meaning of life in one sentence? " - f"Document your reasoning between {Reasoning.xml_tags()} tags.", - }, - ] -).run() - -# Gracefully attempt to parse and deal -# with missing models as None - -reasoning = meaning.last.try_parse(Reasoning) -if reasoning: - print("reasoning:", reasoning.content.strip()) - -# Strip parsed content to avoid sharing -# previous thoughts with the model. - -without_reasons = meaning.strip(Reasoning) -print("meaning of life:", without_reasons.last.content.strip()) - -# follow_up = without_thoughts.continue_(...) diff --git a/tests/test_xml_parsing.py b/tests/test_xml_parsing.py index 6e0d93c..6937e0b 100644 --- a/tests/test_xml_parsing.py +++ b/tests/test_xml_parsing.py @@ -3,7 +3,31 @@ import pytest -from rigging.model import Answer, CommaDelimitedAnswer, DelimitedAnswer, Model, Question, QuestionAnswer, YesNoAnswer +from rigging.model import ( + Answer, + CommaDelimitedAnswer, + DelimitedAnswer, + Model, + Question, + QuestionAnswer, + YesNoAnswer, + attr, + element, +) + + +class NameWithThings(Model): + name: str = attr() + things: list[str] = element("thing") + + +class Inner(Model): + type: str = attr() + content: str + + +class Wrapped(Model): + inners: list[Inner] = element() @pytest.mark.parametrize( @@ -61,6 +85,14 @@ "hello / world / foo / bar, test | value", [(DelimitedAnswer, ["hello", "world", "foo", "bar, test | value"])], ), + pytest.param( + 'ab', + [(NameWithThings, NameWithThings(name="test", things=["a", "b"]))], + ), + pytest.param( + 'meowbark', + [(Wrapped, Wrapped(inners=[Inner(type="cat", content="meow"), Inner(type="dog", content="bark")]))], + ), ], ) def test_xml_parsing(content: str, expectations: list[tuple[Model, str]]) -> None: