diff --git a/docs/api/prompt.md b/docs/api/prompt.md new file mode 100644 index 0000000..614c644 --- /dev/null +++ b/docs/api/prompt.md @@ -0,0 +1 @@ +::: rigging.prompt \ No newline at end of file diff --git a/docs/topics/prompt-functions.md b/docs/topics/prompt-functions.md new file mode 100644 index 0000000..e69de29 diff --git a/mkdocs.yml b/mkdocs.yml index 57ae1ee..c9a2bdc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,6 +11,7 @@ nav: - Models: topics/models.md - Generators: topics/generators.md - Chats and Messages: topics/chats-and-messages.md + - Prompt Functions: topics/prompt-functions.md - Completions: topics/completions.md - Callbacks and Mapping: topics/callbacks-and-mapping.md - Async and Batching: topics/async-and-batching.md @@ -26,6 +27,7 @@ nav: - rigging.generator: api/generator.md - rigging.model: api/model.md - rigging.message: api/message.md + - rigging.prompt: api/prompt.md - rigging.tool: api/tool.md - rigging.data: api/data.md - rigging.parsing: api/parsing.md diff --git a/poetry.lock b/poetry.lock index ca01952..b3dee5d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -634,43 +634,43 @@ toml = ["tomli"] [[package]] name = "cryptography" -version = "42.0.7" +version = "42.0.8" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." optional = true python-versions = ">=3.7" files = [ - {file = "cryptography-42.0.7-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:a987f840718078212fdf4504d0fd4c6effe34a7e4740378e59d47696e8dfb477"}, - {file = "cryptography-42.0.7-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd13b5e9b543532453de08bcdc3cc7cebec6f9883e886fd20a92f26940fd3e7a"}, - {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79165431551042cc9d1d90e6145d5d0d3ab0f2d66326c201d9b0e7f5bf43604"}, - {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a47787a5e3649008a1102d3df55424e86606c9bae6fb77ac59afe06d234605f8"}, - {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:02c0eee2d7133bdbbc5e24441258d5d2244beb31da5ed19fbb80315f4bbbff55"}, - {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:5e44507bf8d14b36b8389b226665d597bc0f18ea035d75b4e53c7b1ea84583cc"}, - {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:7f8b25fa616d8b846aef64b15c606bb0828dbc35faf90566eb139aa9cff67af2"}, - {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:93a3209f6bb2b33e725ed08ee0991b92976dfdcf4e8b38646540674fc7508e13"}, - {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e6b8f1881dac458c34778d0a424ae5769de30544fc678eac51c1c8bb2183e9da"}, - {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3de9a45d3b2b7d8088c3fbf1ed4395dfeff79d07842217b38df14ef09ce1d8d7"}, - {file = "cryptography-42.0.7-cp37-abi3-win32.whl", hash = "sha256:789caea816c6704f63f6241a519bfa347f72fbd67ba28d04636b7c6b7da94b0b"}, - {file = "cryptography-42.0.7-cp37-abi3-win_amd64.whl", hash = "sha256:8cb8ce7c3347fcf9446f201dc30e2d5a3c898d009126010cbd1f443f28b52678"}, - {file = "cryptography-42.0.7-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:a3a5ac8b56fe37f3125e5b72b61dcde43283e5370827f5233893d461b7360cd4"}, - {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:779245e13b9a6638df14641d029add5dc17edbef6ec915688f3acb9e720a5858"}, 
- {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d563795db98b4cd57742a78a288cdbdc9daedac29f2239793071fe114f13785"}, - {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:31adb7d06fe4383226c3e963471f6837742889b3c4caa55aac20ad951bc8ffda"}, - {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:efd0bf5205240182e0f13bcaea41be4fdf5c22c5129fc7ced4a0282ac86998c9"}, - {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a9bc127cdc4ecf87a5ea22a2556cab6c7eda2923f84e4f3cc588e8470ce4e42e"}, - {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:3577d029bc3f4827dd5bf8bf7710cac13527b470bbf1820a3f394adb38ed7d5f"}, - {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2e47577f9b18723fa294b0ea9a17d5e53a227867a0a4904a1a076d1646d45ca1"}, - {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1a58839984d9cb34c855197043eaae2c187d930ca6d644612843b4fe8513c886"}, - {file = "cryptography-42.0.7-cp39-abi3-win32.whl", hash = "sha256:e6b79d0adb01aae87e8a44c2b64bc3f3fe59515280e00fb6d57a7267a2583cda"}, - {file = "cryptography-42.0.7-cp39-abi3-win_amd64.whl", hash = "sha256:16268d46086bb8ad5bf0a2b5544d8a9ed87a0e33f5e77dd3c3301e63d941a83b"}, - {file = "cryptography-42.0.7-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2954fccea107026512b15afb4aa664a5640cd0af630e2ee3962f2602693f0c82"}, - {file = "cryptography-42.0.7-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:362e7197754c231797ec45ee081f3088a27a47c6c01eff2ac83f60f85a50fe60"}, - {file = "cryptography-42.0.7-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4f698edacf9c9e0371112792558d2f705b5645076cc0aaae02f816a0171770fd"}, - {file = "cryptography-42.0.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5482e789294854c28237bba77c4c83be698be740e31a3ae5e879ee5444166582"}, - {file = "cryptography-42.0.7-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e9b2a6309f14c0497f348d08a065d52f3020656f675819fc405fb63bbcd26562"}, - {file = "cryptography-42.0.7-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d8e3098721b84392ee45af2dd554c947c32cc52f862b6a3ae982dbb90f577f14"}, - {file = "cryptography-42.0.7-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c65f96dad14f8528a447414125e1fc8feb2ad5a272b8f68477abbcc1ea7d94b9"}, - {file = "cryptography-42.0.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:36017400817987670037fbb0324d71489b6ead6231c9604f8fc1f7d008087c68"}, - {file = "cryptography-42.0.7.tar.gz", hash = "sha256:ecbfbc00bf55888edda9868a4cf927205de8499e7fabe6c050322298382953f2"}, + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:81d8a521705787afe7a18d5bfb47ea9d9cc068206270aad0b96a725022e18d2e"}, + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:961e61cefdcb06e0c6d7e3a1b22ebe8b996eb2bf50614e89384be54c48c6b63d"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3ec3672626e1b9e55afd0df6d774ff0e953452886e06e0f1eb7eb0c832e8902"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e599b53fd95357d92304510fb7bda8523ed1f79ca98dce2f43c115950aa78801"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5226d5d21ab681f432a9c1cf8b658c0cb02533eece706b155e5fbd8a0cdd3949"}, + {file = 
"cryptography-42.0.8-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6b7c4f03ce01afd3b76cf69a5455caa9cfa3de8c8f493e0d3ab7d20611c8dae9"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:2346b911eb349ab547076f47f2e035fc8ff2c02380a7cbbf8d87114fa0f1c583"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:ad803773e9df0b92e0a817d22fd8a3675493f690b96130a5e24f1b8fabbea9c7"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2f66d9cd9147ee495a8374a45ca445819f8929a3efcd2e3df6428e46c3cbb10b"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d45b940883a03e19e944456a558b67a41160e367a719833c53de6911cabba2b7"}, + {file = "cryptography-42.0.8-cp37-abi3-win32.whl", hash = "sha256:a0c5b2b0585b6af82d7e385f55a8bc568abff8923af147ee3c07bd8b42cda8b2"}, + {file = "cryptography-42.0.8-cp37-abi3-win_amd64.whl", hash = "sha256:57080dee41209e556a9a4ce60d229244f7a66ef52750f813bfbe18959770cfba"}, + {file = "cryptography-42.0.8-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:dea567d1b0e8bc5764b9443858b673b734100c2871dc93163f58c46a97a83d28"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4783183f7cb757b73b2ae9aed6599b96338eb957233c58ca8f49a49cc32fd5e"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0608251135d0e03111152e41f0cc2392d1e74e35703960d4190b2e0f4ca9c70"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dc0fdf6787f37b1c6b08e6dfc892d9d068b5bdb671198c72072828b80bd5fe4c"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9c0c1716c8447ee7dbf08d6db2e5c41c688544c61074b54fc4564196f55c25a7"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fff12c88a672ab9c9c1cf7b0c80e3ad9e2ebd9d828d955c126be4fd3e5578c9e"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cafb92b2bc622cd1aa6a1dce4b93307792633f4c5fe1f46c6b97cf67073ec961"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:31f721658a29331f895a5a54e7e82075554ccfb8b163a18719d342f5ffe5ecb1"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b297f90c5723d04bcc8265fc2a0f86d4ea2e0f7ab4b6994459548d3a6b992a14"}, + {file = "cryptography-42.0.8-cp39-abi3-win32.whl", hash = "sha256:2f88d197e66c65be5e42cd72e5c18afbfae3f741742070e3019ac8f4ac57262c"}, + {file = "cryptography-42.0.8-cp39-abi3-win_amd64.whl", hash = "sha256:fa76fbb7596cc5839320000cdd5d0955313696d9511debab7ee7278fc8b5c84a"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ba4f0a211697362e89ad822e667d8d340b4d8d55fae72cdd619389fb5912eefe"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:81884c4d096c272f00aeb1f11cf62ccd39763581645b0812e99a91505fa48e0c"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c9bb2ae11bfbab395bdd072985abde58ea9860ed84e59dbc0463a5d0159f5b71"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7016f837e15b0a1c119d27ecd89b3515f01f90a8615ed5e9427e30d9cdbfed3d"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5a94eccb2a81a309806027e1670a358b99b8fe8bfe9f8d329f27d72c094dde8c"}, + {file = 
"cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dec9b018df185f08483f294cae6ccac29e7a6e0678996587363dc352dc65c842"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:343728aac38decfdeecf55ecab3264b015be68fc2816ca800db649607aeee648"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:013629ae70b40af70c9a7a5db40abe5d9054e6f4380e50ce769947b73bf3caad"}, + {file = "cryptography-42.0.8.tar.gz", hash = "sha256:8d09d05439ce7baa8e9e95b07ec5b6c886f548deb7e0f69ef25f64b3bce842f2"}, ] [package.dependencies] @@ -1217,13 +1217,13 @@ socks = ["socksio (==1.*)"] [[package]] name = "huggingface-hub" -version = "0.23.2" +version = "0.23.3" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"}, - {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"}, + {file = "huggingface_hub-0.23.3-py3-none-any.whl", hash = "sha256:22222c41223f1b7c209ae5511d2d82907325a0e3cdbce5f66949d43c598ff3bc"}, + {file = "huggingface_hub-0.23.3.tar.gz", hash = "sha256:1a1118a0b3dea3bab6c325d71be16f5ffe441d32f3ac7c348d6875911b694b5b"}, ] [package.dependencies] @@ -2507,13 +2507,13 @@ files = [ [[package]] name = "openai" -version = "1.31.0" +version = "1.31.1" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.31.0-py3-none-any.whl", hash = "sha256:82044ee3122113f2a468a1f308a8882324d09556ba5348687c535d3655ee331c"}, - {file = "openai-1.31.0.tar.gz", hash = "sha256:54ae0625b005d6a3b895db2b8438dae1059cffff0cd262a26e9015c13a29ab06"}, + {file = "openai-1.31.1-py3-none-any.whl", hash = "sha256:a746cf070798a4048cfea00b0fc7cb9760ee7ead5a08c48115b914d1afbd1b53"}, + {file = "openai-1.31.1.tar.gz", hash = "sha256:a15266827de20f407d4bf9837030b168074b5b29acd54f10bb38d5f53e95f083"}, ] [package.dependencies] @@ -5093,6 +5093,17 @@ files = [ numpy = "*" torch = "2.3.0" +[[package]] +name = "xmltodict" +version = "0.13.0" +description = "Makes working with XML feel like you are working with JSON" +optional = false +python-versions = ">=3.4" +files = [ + {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"}, + {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"}, +] + [[package]] name = "yarl" version = "1.9.4" @@ -5218,4 +5229,4 @@ examples = ["asyncssh", "click", "httpx", "types-requests"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "bdbac14ad41e77ecb3d7a47f5a8892fad141aba1dc7cc14d9b5cd80a25527c5d" +content-hash = "a6b1bec0675b4f10e46a6cb31c76dc33e12c3f5c3bd9c6f5677083886a6274de" diff --git a/pyproject.toml b/pyproject.toml index 28d1219..a0dcd2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ asyncssh = { version = "^2.14.2", optional = true } types-requests = { version = "^2.32.0.20240523", optional = true } click = { version = "^8.1.7", optional = true } httpx = { version = "^0.27.0", optional = true } +xmltodict = "^0.13.0" [tool.poetry.extras] examples = ["asyncssh", "types-requests", "click", "httpx"] diff --git a/rigging/__init__.py b/rigging/__init__.py index 
84250ab..7c43310 100644 --- a/rigging/__init__.py +++ b/rigging/__init__.py @@ -22,6 +22,7 @@ ) from rigging.message import Message, MessageDict, Messages from rigging.model import Model, attr, element, make_primitive, wrapped +from rigging.prompt import Ctx, Prompt, prompt from rigging.tool import Tool __version__ = "1.3.0" @@ -56,6 +57,9 @@ "chats_to_elastic_data", "flatten_chats", "unflatten_chats", + "prompt", + "Prompt", + "Ctx", ] from loguru import logger diff --git a/rigging/chat.py b/rigging/chat.py index 1af5ba1..7655b64 100644 --- a/rigging/chat.py +++ b/rigging/chat.py @@ -28,6 +28,8 @@ from elasticsearch import AsyncElasticsearch from rigging.data import ElasticOpType + from rigging.prompt import Prompt + from rigging.util import P, R DEFAULT_MAX_ROUNDS = 5 """Maximum number of internal callback rounds to attempt during generation before giving up.""" @@ -203,7 +205,16 @@ def continue_(self, messages: t.Sequence[Message] | t.Sequence[MessageDict] | Me return self.fork(messages, include_all=True) def clone(self, *, only_messages: bool = False) -> Chat: - """Creates a deep copy of the chat.""" + """ + Creates a deep copy of the chat. + + Args: + only_messages: If True, only the messages will be cloned. + If False (default), the entire chat object will be cloned. + + Returns: + A new instance of Chat. + """ new = Chat( [m.model_copy() for m in self.messages], [m.model_copy() for m in self.generated], @@ -211,6 +222,11 @@ def clone(self, *, only_messages: bool = False) -> Chat: ) if not only_messages: new.metadata = deepcopy(self.metadata) + new.params = self.params.model_copy() if self.params is not None else None + new.stop_reason = self.stop_reason + new.usage = self.usage.model_copy() if self.usage is not None else None + new.extra = deepcopy(self.extra) + new.failed = self.failed return new def apply(self, **kwargs: str) -> Chat: @@ -371,7 +387,7 @@ async def to_elastic( """ from rigging.data import chats_to_elastic - return chats_to_elastic(self, index, client, op_type=op_type, create_index=create_index, **kwargs) + return await chats_to_elastic(self, index, client, op_type=op_type, create_index=create_index, **kwargs) # Callbacks @@ -564,7 +580,12 @@ def clone(self, *, only_messages: bool = False) -> ChatPipeline: Returns: A new instance of `ChatPipeline` that is a clone of the current instance. 
""" - new = ChatPipeline(self.generator, [], params=self.params, watch_callbacks=self.watch_callbacks) + new = ChatPipeline( + self.generator, + [], + params=self.params.model_copy() if self.params is not None else None, + watch_callbacks=self.watch_callbacks, + ) new.chat = self.chat.clone() if not only_messages: new.until_callbacks = self.until_callbacks.copy() @@ -573,8 +594,8 @@ def clone(self, *, only_messages: bool = False) -> ChatPipeline: new.inject_tool_prompt = self.inject_tool_prompt new.force_tool = self.force_tool new.metadata = deepcopy(self.metadata) - new.then_chat_callbacks = self.then_chat_callbacks.copy() - new.map_chat_callbacks = self.map_chat_callbacks.copy() + new.then_callbacks = self.then_callbacks.copy() + new.map_callbacks = self.map_callbacks.copy() return new def meta(self, **kwargs: t.Any) -> ChatPipeline: @@ -1157,3 +1178,72 @@ async def run_batch( chats = [s.chat for s in states if s.chat is not None] return await self._post_run(chats) + + # Generator iteration + + async def arun_over( + self, + *generators: Generator | str, + include_original: bool = True, + skip_failed: bool = False, + include_failed: bool = False, + ) -> ChatList: + """ + Executes the generation process across multiple generators. + + For each generator, this pending chat is cloned and the generator is replaced + before the run call. All callbacks and parameters are preserved. + + Parameters: + *generators: A sequence of generators to be used for the generation process. + include_original: Whether to include the original generator in the list of runs. + skip_failed: Enable to ignore any max rounds errors and return only successful chats. + include_failed: Enable to ignore max rounds errors and return both + successful and failed chats. + + Returns: + A list of generatated Chats. + """ + if skip_failed and include_failed: + raise ValueError("Cannot use both skip_failed and include_failed") + + _generators: list[Generator] = [g if isinstance(g, Generator) else get_generator(g) for g in generators] + if include_original: + _generators.append(self.generator) + + coros: list[t.Coroutine[t.Any, t.Any, Chat]] = [] + for generator in _generators: + sub = self.clone() + sub.generator = generator + coros.append(sub.run(allow_failed=skip_failed or include_failed)) + + chats = await asyncio.gather(*coros) + + if skip_failed: + chats = [c for c in chats if not c.failed] + + return ChatList(chats) + + # Prompt functions + + def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]: + """ + Decorator to convert a function into a prompt bound to this pipeline. + + See [rigging.prompt.prompt][] for more information. + + Args: + func: The function to be converted into a prompt. + + Returns: + The prompt. + """ + from rigging.prompt import prompt + + return prompt(func, pipeline=self) + + async def run_prompt(self, prompt: Prompt[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + return await prompt.run(*args, pipeline=self, **kwargs) + + async def run_prompt_many(self, prompt: Prompt[P, R], count: int, *args: P.args, **kwargs: P.kwargs) -> list[R]: + return await prompt.run_many(count, *args, pipeline=self, **kwargs) diff --git a/rigging/completion.py b/rigging/completion.py index 5d6fa8c..1028846 100644 --- a/rigging/completion.py +++ b/rigging/completion.py @@ -117,7 +117,6 @@ def restart(self, *, generator: t.Optional[Generator] = None, include_all: bool generator: The generator to use for the restarted completion. 
Otherwise the generator from the original CompletionPipeline will be used. include_all: Whether to include the generation before the next round. - Returns: The restarted completion. @@ -153,6 +152,11 @@ def clone(self, *, only_messages: bool = False) -> Completion: new = Completion(self.text, self.generated, self.generator) if not only_messages: new.metadata = deepcopy(self.metadata) + new.stop_reason = self.stop_reason + new.usage = self.usage.model_copy() if self.usage is not None else self.usage + new.extra = deepcopy(self.extra) + new.params = self.params.model_copy() if self.params is not None else self.params + new.failed = self.failed return new def meta(self, **kwargs: t.Any) -> Completion: @@ -388,12 +392,18 @@ def clone(self, *, only_text: bool = False) -> CompletionPipeline: Returns: A new instance of `CompletionPipeline` that is a clone of the current instance. """ - new = CompletionPipeline(self.generator, self.text, params=self.params, watch_callbacks=self.watch_callbacks) + new = CompletionPipeline( + self.generator, + self.text, + params=self.params.model_copy() if self.params is not None else None, + watch_callbacks=self.watch_callbacks, + ) if not only_text: new.until_callbacks = self.until_callbacks.copy() new.until_types = self.until_types.copy() new.metadata = deepcopy(self.metadata) new.then_callbacks = self.then_callbacks.copy() + new.map_callbacks = self.map_callbacks.copy() return new def meta(self, **kwargs: t.Any) -> CompletionPipeline: diff --git a/rigging/generator/base.py b/rigging/generator/base.py index 19f2a9e..f8fe50a 100644 --- a/rigging/generator/base.py +++ b/rigging/generator/base.py @@ -13,6 +13,8 @@ if t.TYPE_CHECKING: from rigging.chat import ChatPipeline, WatchChatCallback from rigging.completion import CompletionPipeline, WatchCompletionCallback + from rigging.prompt import Prompt + from rigging.util import P, R WatchCallbacks = t.Union[WatchChatCallback, WatchCompletionCallback] @@ -28,13 +30,11 @@ def __call__(self) -> type[Generator]: g_providers: dict[str, type[Generator] | LazyGenerator] = {} -# TODO: Ideally we flex this to support arbitrary -# generator params, but we'll limit things -# for now until we understand the use cases -# # TODO: We also would like to support N-style # parallel generation eventually -> need to # update our interfaces to support that + + class GenerateParams(BaseModel): """ Parameters for generating text using a language model. @@ -252,7 +252,7 @@ def watch(self, *callbacks: WatchCallbacks, allow_duplicates: bool = False) -> G allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added. ``` - def log(chats: list[Chat]) -> None: + async def log(chats: list[Chat]) -> None: ... pipeline.watch(log).run() @@ -398,6 +398,22 @@ def complete(self, text: str, params: GenerateParams | None = None) -> Completio return CompletionPipeline(self, text, params=params, watch_callbacks=completion_watch_callbacks) + def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]: + """ + Decorator to convert a function into a prompt bound to this generator. + + See [rigging.prompt.prompt][] for more information. + + Args: + func: The function to be converted into a prompt. + + Returns: + The prompt. 
+ """ + from rigging.prompt import prompt + + return prompt(func, generator=self) + @t.overload def chat( diff --git a/rigging/message.py b/rigging/message.py index a708dab..d0a26a2 100644 --- a/rigging/message.py +++ b/rigging/message.py @@ -133,7 +133,7 @@ def _remove_part(self, part: ParsedMessagePart) -> str: def _add_part(self, part: ParsedMessagePart) -> None: for existing in self.parts: - if part.slice_ == existing.slice_ and isinstance(part.model, type(existing.model)): + if part.slice_ == existing.slice_ and part.model.xml_tags() == existing.model.xml_tags(): return # We clearly already have this part defined if max(part.slice_.start, existing.slice_.start) < min(part.slice_.stop, existing.slice_.stop): raise ValueError("Incoming part overlaps with an existing part") @@ -368,7 +368,7 @@ def fit(cls, message: t.Union[Message, MessageDict, str]) -> Message: """Helper function to convert various common types to a Message object.""" if isinstance(message, str): return cls(role="user", content=message) - return cls(**message) if isinstance(message, dict) else message + return cls(**message) if isinstance(message, dict) else message.model_copy() @classmethod def apply_to_list(cls, messages: t.Sequence[Message], **kwargs: str) -> list[Message]: diff --git a/rigging/model.py b/rigging/model.py index 57abfdf..be77d5d 100644 --- a/rigging/model.py +++ b/rigging/model.py @@ -8,6 +8,7 @@ import typing as t from xml.etree import ElementTree as ET +import xmltodict # type: ignore from pydantic import ( BeforeValidator, SerializationInfo, @@ -16,7 +17,6 @@ field_serializer, field_validator, ) -from pydantic.alias_generators import to_snake from pydantic_xml import BaseXmlModel from pydantic_xml import attr as attr from pydantic_xml import element as element @@ -24,6 +24,7 @@ from pydantic_xml.typedefs import EntityLocation, NsMap from rigging.error import MissingModelError +from rigging.util import escape_xml, to_xml_tag, unescape_xml if t.TYPE_CHECKING: from pydantic_xml.element import SearchMode # type: ignore [attr-defined] @@ -42,25 +43,6 @@ BASIC_TYPES = [int, str, float, bool] -def escape_xml(xml_string: str) -> str: - prepared = re.sub(r"&(?!(?:amp|lt|gt|apos|quot);)", "&", xml_string) - - return prepared - - -def unescape_xml(xml_string: str) -> str: - # We only expect to use this in our "simple" - # models, but I'd like a better long-term solution - - unescaped = re.sub(r"&", "&", xml_string) - unescaped = re.sub(r"<", "<", unescaped) - unescaped = re.sub(r">", ">", unescaped) - unescaped = re.sub(r"'", "'", unescaped) - unescaped = re.sub(r""", '"', unescaped) - - return unescaped - - class XmlTagDescriptor: def __get__(self, _: t.Any, owner: t.Any) -> str: mro_iter = iter(owner.mro()) @@ -79,7 +61,7 @@ def __get__(self, _: t.Any, owner: t.Any) -> str: if "[" in cls.__name__: return t.cast(str, parent.__xml_tag__) - return to_snake(cls.__name__).replace("_", "-") + return to_xml_tag(cls.__name__) class Model(BaseXmlModel): @@ -118,6 +100,8 @@ def to_pretty_xml(self) -> str: pretty_encoded_xml = ET.tostring(tree, short_empty_elements=False).decode() if self.__class__.is_simple(): + # We only expect to use this in our "simple" + # models, but I'd like a better long-term solution return unescape_xml(pretty_encoded_xml) else: return pretty_encoded_xml @@ -167,12 +151,18 @@ def xml_example(cls) -> str: Models should typically override this method to provide a more complex example. - By default, this method just returns the XML tags for the class. 
+ By default, this method returns a hollow XML scaffold one layer deep. Returns: A string containing the XML representation of the class. """ - return cls.xml_tags() + schema = cls.model_json_schema() + properties = schema["properties"] + structure = {cls.__xml_tag__: {field: None for field in properties}} + xml_string = xmltodict.unparse( + structure, pretty=True, full_document=False, indent=" ", short_empty_elements=True + ) + return t.cast(str, xml_string) # Bad type hints in xmltodict @classmethod def ensure_valid(cls) -> None: @@ -299,7 +289,7 @@ class Primitive(Model, t.Generic[PrimitiveT]): def make_primitive( name: str, - type_: PrimitiveT = str, + type_: type[PrimitiveT] = str, # type: ignore [assignment] *, tag: str | None = None, doc: str | None = None, @@ -328,7 +318,7 @@ def _validate(value: str) -> str: return create_model( name, - __base__=Primitive[type_], + __base__=Primitive[type_], # type: ignore __doc__=doc, __cls_kwargs__={"tag": tag}, content=(t.Annotated[type_, BeforeValidator(lambda x: x.strip() if isinstance(x, str) else x)], ...), diff --git a/rigging/prompt.py b/rigging/prompt.py index af8f013..3fa4786 100644 --- a/rigging/prompt.py +++ b/rigging/prompt.py @@ -1,182 +1,788 @@ from __future__ import annotations +import asyncio +import dataclasses import inspect +import re import typing as t -from jinja2 import Environment, meta -from pydantic import BaseModel, computed_field, model_validator -from typing_extensions import ParamSpec +from jinja2 import Environment, StrictUndefined, meta +from pydantic import ValidationError +from rigging.chat import Chat from rigging.generator.base import GenerateParams, Generator, get_generator -from rigging.model import Model +from rigging.message import Message +from rigging.model import Model, SystemErrorModel, ValidationErrorModel, make_primitive +from rigging.util import P, R, escape_xml, to_snake, to_xml_tag if t.TYPE_CHECKING: - from rigging.chat import ChatPipeline + from rigging.chat import ChatPipeline, WatchChatCallback -DEFAULT_DOC = "You will convert the following inputs to outputs." +DEFAULT_DOC = "Convert the following inputs to outputs ({func_name})." +"""Default docstring if none is provided to a prompt function.""" + +DEFAULT_MAX_ROUNDS = 3 +"""Default maximum number of rounds for a prompt to run until outputs are parsed.""" + +# Annotation + + +@dataclasses.dataclass +class Ctx: + """ + Used in type annotations to provide additional context for the prompt construction. + + You can use this annotation on inputs and ouputs to prompt functions. 
+ + ``` + tag_override = Annotated[str, Ctx(tag="custom_tag", ...)] + ``` + """ + + tag: str | None = None + prefix: str | None = None + example: str | None = None -P = ParamSpec("P") -R = t.TypeVar("R") -BASIC_TYPES = [int, float, str, bool, list, dict, set, tuple, type(None)] # Utilities -def get_undefined_values(template: str) -> set[str]: +def unwrap_annotated(annotation: t.Any) -> tuple[t.Any, t.Optional[Ctx]]: + if t.get_origin(annotation) is t.Annotated: + base_type, *meta = t.get_args(annotation) + for m in meta: + if isinstance(m, Ctx): + return base_type, m + return base_type, None + return annotation, None + + +def get_undeclared_variables(template: str) -> set[str]: env = Environment() parsed_template = env.parse(template) return meta.find_undeclared_variables(parsed_template) -def format_parameter(param: inspect.Parameter, value: t.Any) -> str: - name = param.name +def make_parameter( + annotation: t.Any, *, name: str = "", kind: inspect._ParameterKind = inspect.Parameter.VAR_KEYWORD +) -> inspect.Parameter: + return inspect.Parameter(name=name, kind=kind, annotation=annotation) + + +# Function Inputs + + +@dataclasses.dataclass +class Input: + name: str + context: Ctx + + @property + def tag(self) -> str: + return self.context.tag or to_xml_tag(self.name) + + def _prefix(self, xml: str) -> str: + if self.context.prefix: + return f"{self.context.prefix}\n{xml}" + return xml + + def to_str(self, value: t.Any) -> str: + raise NotImplementedError + + def to_xml(self, value: t.Any) -> str: + value_str = self.to_str(value) + if "\n" in value_str: + value_str = f"\n{value_str}\n" + return self._prefix(f"<{self.tag}>{escape_xml(value_str)}") + + +@dataclasses.dataclass +class BasicInput(Input): + def to_str(self, value: t.Any) -> str: + if not isinstance(value, (int, float, str, bool)): + raise ValueError(f"Value must be a basic type, got: {type(value)}") - if isinstance(value, str): - if "\n" in value: - value = f"\n{value.strip()}\n" - return f"<{name}>{value}" + return str(value) - if isinstance(value, (int, float)): - return f"<{name}>{value}" - if isinstance(value, bool): - return f"<{name}>{'true' if value else 'false'}" +@dataclasses.dataclass +class ModelInput(Input): + def to_str(self, value: t.Any) -> str: + if not isinstance(value, Model): + raise ValueError(f"Value must be a Model instance, got: {type(value)}") - if isinstance(value, Model): return value.to_pretty_xml() - if isinstance(value, (list, set)): - type_args = t.get_args(param.annotation) + def to_xml(self, value: t.Any) -> str: + return self._prefix(self.to_str(value)) + + +@dataclasses.dataclass +class ListInput(Input): + interior: Input + + def to_str(self, value: list[t.Any]) -> str: + return "\n\n".join(self.interior.to_str(v) for v in value) + - xml = f"<{name}>\n" - for item in value: - pass +@dataclasses.dataclass +class DictInput(Input): + interior: Input - raise ValueError(f"Unsupported parameter: {param}: '{value}'") + def to_str(self, value: t.Any) -> str: + if not isinstance(value, dict): + raise ValueError(f"Value must be a dictionary, got: {type(value)}") + if not all(isinstance(k, str) for k in value.keys()): + raise ValueError("Dictionary keys must be strings") + return "\n".join(f"<{k}>{self.interior.to_str(v)}" for k, v in value.items()) -def check_valid_function(func: t.Callable[P, t.Coroutine[None, None, R]]) -> None: - signature = inspect.signature(func) - for param in signature.parameters.values(): - error_name = f"{func.__name__}({param})" +def parse_parameter(param: inspect.Parameter, 
error_name: str) -> Input: + if param.kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + raise TypeError(f"Parameters must be positional or keyword {error_name}") - if param.kind not in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ): - raise TypeError(f"Parameters must be positional or keyword {error_name}") + if param.annotation in [None, inspect.Parameter.empty]: + raise TypeError(f"All parameters require type annotations {error_name}") - if param.annotation in [None, inspect.Parameter.empty]: - raise TypeError(f"All parameters require type annotations {error_name}") + annotation, context = unwrap_annotated(param.annotation) - origin = t.get_origin(param.annotation) - if origin is not None: - # Check for a dict[str, str] - if origin is dict: - if param.annotation.__args__[0] != (str, str): - raise TypeError(f"Dicts must have str keys {error_name}") - raise TypeError(f"Parameters cannot be generic, lists, sets, or dicts {error_name}") + if annotation in [int, float, str, bool]: + return BasicInput(param.name, context or Ctx()) - if param.annotation in [int, float, str, bool]: - continue + if inspect.isclass(annotation) and issubclass(annotation, Model): + return ModelInput(param.name, context or Ctx()) - if issubclass(param.annotation, Model): - continue + if t.get_origin(annotation) is list: + if not param.name: + raise TypeError(f"Nested list parameters are not supported: {error_name}") - raise TypeError( - f"Invalid parameter type: {param.annotation}, must be one of int, bool, str, float or rg.Model ({func.__name__}#{param.name})" + args = t.get_args(annotation) + if not args: + raise TypeError(f"List param must be fully typed: {error_name}") + + arg_type, arg_context = unwrap_annotated(args[0]) + return ListInput( + param.name, arg_context or context or Ctx(), parse_parameter(make_parameter(arg_type), error_name) ) - if signature.return_annotation in [None, inspect.Parameter.empty]: - raise TypeError(f"Return type annotation is required ({func.__name__})") + elif t.get_origin(annotation) is dict: + if not param.name: + raise TypeError(f"Nested dict parameters are not supported: {error_name}") + + args = t.get_args(annotation) + if not args or len(args) != 2: + raise TypeError(f"Dict param must be fully typed: {error_name}") + if args[0] is not str: + raise TypeError(f"Dict param keys must be strings: {error_name}") + + return DictInput(param.name, context or Ctx(), parse_parameter(make_parameter(args[1]), error_name)) + + raise TypeError(f"Unsupported parameter type: {error_name}") + + +# Function Outputs + + +@dataclasses.dataclass +class Output: + id: str + context: Ctx + + @property + def tag(self) -> str: + return self.context.tag or to_xml_tag(self.id) + + def _prefix(self, xml: str) -> str: + if self.context.prefix: + return f"{self.context.prefix}\n{xml}" + return xml + + def guidance(self) -> str: + return "Produce the following output:" + + def to_format(self) -> str: + tag = self.context.tag or self.tag + return self._prefix(f"<{tag}>{escape_xml(self.context.example or '')}") + + def from_chat(self, chat: Chat) -> t.Any: + raise NotImplementedError + + +@dataclasses.dataclass +class ChatOutput(Output): + def from_chat(self, chat: Chat) -> t.Any: + return chat + + +@dataclasses.dataclass +class BasicOutput(Output): + type_: type[t.Any] # TODO: We should be able to scope this down + + def from_chat(self, chat: Chat) -> t.Any: + Temp = make_primitive("Model", self.type_, tag=self.context.tag or 
self.tag) + return chat.last.parse(Temp).content + + +@dataclasses.dataclass +class BasicListOutput(BasicOutput): + def guidance(self) -> str: + return "Produce the following output for each item:" + + def from_chat(self, chat: Chat) -> t.Any: + Model = make_primitive("Model", self.type_, tag=self.context.tag or self.tag) + return [m.content for m in chat.last.parse_set(Model)] + + +@dataclasses.dataclass +class ModelOutput(Output): + type_: type[Model] + + def to_format(self) -> str: + return self.type_.xml_example() + + def from_chat(self, chat: Chat) -> t.Any: + return chat.last.parse(self.type_) + + +@dataclasses.dataclass +class ModelListOutput(ModelOutput): + def guidance(self) -> str: + return "Produce the following output for each item:" + + def from_chat(self, chat: Chat) -> t.Any: + return chat.last.parse_set(self.type_) + + +@dataclasses.dataclass +class TupleOutput(Output): + interiors: list[Output] + + @property + def real_interiors(self) -> list[Output]: + return [i for i in self.interiors if not isinstance(i, ChatOutput)] + + @property + def wrapped(self) -> bool: + # Handles cases where we are using a tuple just to + # capture a Chat along with a real output, in this + # case we should fall through for most of the work + # + # () -> tuple[Chat, ...] + return len(self.real_interiors) == 1 + + def guidance(self) -> str: + if self.wrapped: + return self.real_interiors[0].guidance() + return "Produce the following outputs:" + + def to_format(self) -> str: + if self.wrapped: + return self.real_interiors[0].to_format() + return self._prefix("\n\n".join(i.to_format() for i in self.real_interiors)) + + def from_chat(self, chat: Chat) -> t.Any: + return tuple(i.from_chat(chat) for i in self.interiors) + + +@dataclasses.dataclass +class DataclassOutput(TupleOutput): + type_: type[t.Any] - if not isinstance(signature.return_annotation, tuple): - return + def from_chat(self, chat: Chat) -> t.Any: + return self.type_(*super().from_chat(chat)) -def build_template(func: t.Callable) -> str: - docstring = func.__doc__ or DEFAULT_DOC - docstring = inspect.cleandoc(docstring) +def parse_output(annotation: t.Any, error_name: str, *, allow_nested: bool = True) -> Output: + from rigging.chat import Chat - signature = inspect.signature(func) + if annotation in [None, inspect.Parameter.empty]: + raise TypeError(f"Return type annotation is required ({error_name})") + + # Unwrap any annotated types + annotation, context = unwrap_annotated(annotation) + + if annotation == Chat: + # Use a special subclass here -> args don't matter + return ChatOutput(id="chat", context=context or Ctx()) + + if annotation in [int, float, str, bool]: + return BasicOutput(id=annotation.__name__, context=context or Ctx(), type_=annotation) + + if t.get_origin(annotation) is list: + if not allow_nested: + raise TypeError(f"Nested list outputs are not supported ({error_name})") + + args = t.get_args(annotation) + if not args: + raise TypeError(f"List return type must be fully specified ({error_name})") + + arg_type, arg_context = unwrap_annotated(args[0]) + + if arg_type in [int, float, str, bool]: + return BasicListOutput(id=arg_type.__name__, context=arg_context or context or Ctx(), type_=arg_type) + + if inspect.isclass(arg_type) and issubclass(arg_type, Model): + return ModelListOutput(id=arg_type.__name__, context=arg_context or context or Ctx(), type_=arg_type) + + if t.get_origin(annotation) is tuple: + if not allow_nested: + raise TypeError(f"Nested tuple outputs are not supported ({error_name})") + + args = 
t.get_args(annotation) + if not args: + raise TypeError(f"Tuple return type must be fully specified ({error_name})") + + tuple_interiors = [parse_output(arg, error_name, allow_nested=False) for arg in args] + + if len({i.tag for i in tuple_interiors}) != len(tuple_interiors): + raise TypeError( + f"Tuple return annotations must have unique internal types\n" + "or use Annotated[..., Context(tag=...)] overrides to\n" + f"make them differentiable ({error_name})" + ) + + return TupleOutput(id="tuple", context=context or Ctx(), interiors=tuple_interiors) + + if dataclasses.is_dataclass(annotation): + interior_annotations: list[t.Any] = [] + for field in dataclasses.fields(annotation): + field_annotation, field_context = unwrap_annotated(field.type) + if field_context is None: + field_annotation = t.Annotated[field_annotation, Ctx(tag=to_xml_tag(field.name))] + interior_annotations.append(field_annotation) + + dataclass_interiors: list[Output] = [] + for field, field_annotation in zip(dataclasses.fields(annotation), interior_annotations): + interior = parse_output(field_annotation, f"{error_name}#{field.name}", allow_nested=False) + if interior is None: + raise TypeError(f"Dataclass field type is invalid ({error_name}#{field.name}") + dataclass_interiors.append(interior) + + if len({i.tag for i in dataclass_interiors}) != len(dataclass_interiors): + raise TypeError( + f"Dataclass return annotations must have unique internal types\n" + "or use Annotated[..., Context(tag=...)] overrides to\n" + f"make them differentiable ({error_name})" + ) + + return DataclassOutput( + id=annotation.__name__, type_=annotation, context=context or Ctx(), interiors=dataclass_interiors + ) + + # This has to come after our list/tuple checks as they pass isclass + if inspect.isclass(annotation) and issubclass(annotation, Model): + return ModelOutput(id=annotation.__name__, context=context or Ctx(), type_=annotation) + + raise TypeError(f"Unsupported return type: {error_name}") # Prompt -class Prompt(BaseModel, t.Generic[P, R]): - _func: t.Callable[P, t.Coroutine[None, None, R]] +@dataclasses.dataclass +class Prompt(t.Generic[P, R]): + """ + Prompts wrap hollow functions and create structured chat interfaces for + passing inputs into a ChatPipeline and parsing outputs. + """ + + func: t.Callable[P, t.Coroutine[None, None, R]] + """The function that the prompt wraps. 
This function should be a coroutine.""" + + attempt_recovery: bool = True + """Whether the prompt should attempt to recover from errors in output parsing.""" + drop_dialog: bool = True + """When attempting recovery, whether to drop intermediate dialog while parsing was being resolved.""" + max_rounds: int = DEFAULT_MAX_ROUNDS + """The maximum number of rounds the prompt should try to reparse outputs.""" + + inputs: list[Input] = dataclasses.field(default_factory=list) + """The structured input handlers for the prompt.""" + output: Output = ChatOutput(id="chat", context=Ctx()) + """The structured output handler for the prompt.""" + + watch_callbacks: list[WatchChatCallback] = dataclasses.field(default_factory=list) + """Callbacks to be passed any chats produced while executing this prompt.""" + params: GenerateParams | None = None + """The parameters to be used when generating chats for this prompt.""" _generator_id: str | None = None _generator: Generator | None = None _pipeline: ChatPipeline | None = None - _params: GenerateParams | None = None - @model_validator(mode="after") - def check_valid_function(self) -> Prompt[P, R]: - check_valid_function(self._func) - return self + def __post_init__(self) -> None: + signature = inspect.signature(self.func) + undeclared = get_undeclared_variables(self.docstring) + + for param in signature.parameters.values(): + if param.name in undeclared: + continue + error_name = f"{self.func.__name__}({param})" + self.inputs.append(parse_parameter(param, error_name)) + + if len({i.tag for i in self.inputs}) != len(self.inputs): + raise TypeError("All input parameters must have unique names/tags") + + error_name = f"{self.func.__name__}() -> {signature.return_annotation}" + self.output = parse_output(signature.return_annotation, error_name) + + @property + def docstring(self) -> str: + """The docstring for the prompt function.""" + # Guidance is taken from https://github.com/outlines-dev/outlines/blob/main/outlines/prompts.py + docstring = self.func.__doc__ or DEFAULT_DOC.format(func_name=self.func.__name__) + docstring = inspect.cleandoc(docstring) + docstring = re.sub(r"(?![\r\n])(\b\s+)", " ", docstring) + return docstring - @computed_field # type: ignore [misc] @property def template(self) -> str: - return "" + """The dynamic jinja2 template for the prompt function.""" + text = f"{self.docstring}\n" + + for input_ in self.inputs: + text += "\n{{ " + to_snake(input_.tag) + " }}\n" + + if self.output is None or isinstance(self.output, ChatOutput): + return text + + text += f"\n{self.output.guidance()}\n" + text += f"\n{self.output.to_format()}\n" + + return text @property def pipeline(self) -> ChatPipeline | None: + """If available, the resolved Chat Pipeline for the prompt.""" if self._pipeline is not None: - return self._pipeline.with_(params=self._params) + return self._pipeline + + if self._generator is None and self._generator_id is not None: + self._generator = get_generator(self._generator_id) - if self._generator is None: - if self._generator_id is None: - raise ValueError( - "You cannot execute this prompt ad-hoc. No pipeline, generator, or generator_id was provided." 
+ if self._generator is not None: + self._pipeline = self._generator.chat() + return self._pipeline + + return None + + def _until_parsed(self, message: Message) -> tuple[bool, list[Message]]: + should_continue: bool = False + generated: list[Message] = [message] + + if self.output is None or isinstance(self.output, ChatOutput): + return (should_continue, generated) + + try: + # A bit weird, but we need from_chat to properly handle + # wrapping Chat output types inside lists/dataclasses + self.output.from_chat(Chat([], generated=[message])) + except ValidationError as e: + should_continue = True + generated.append( + Message.from_model( + ValidationErrorModel(content=str(e)), + suffix="Rewrite your entire message with all the required elements.", ) + ) + except Exception as e: + should_continue = True + generated.append( + Message.from_model( + SystemErrorModel(content=str(e)), + suffix="Rewrite your entire message with all the required elements.", + ) + ) + + return (should_continue, generated) + + def clone(self, *, skip_callbacks: bool = False) -> Prompt[P, R]: + """ + Creates a deep copy of this prompt. + + Args: + skip_callbacks: Whether to skip copying the watch callbacks. + + Returns: + A new instance of the prompt. + """ + new = Prompt( + func=self.func, + _pipeline=self.pipeline, + params=self.params.model_copy() if self.params is not None else None, + attempt_recovery=self.attempt_recovery, + drop_dialog=self.drop_dialog, + max_rounds=self.max_rounds, + ) + if not skip_callbacks: + new.watch_callbacks = self.watch_callbacks.copy() + return new - self._generator = get_generator(self._generator_id) + def with_(self, params: t.Optional[GenerateParams] = None, **kwargs: t.Any) -> Prompt[P, R]: + """ + Assign specific generation parameter overloads for this prompt. - return self._generator.chat(params=self._params) + Note: + This will trigger a `clone` if overload params have already been set. - def clone(self) -> Prompt[P, R]: - return Prompt(_func=self._func, _pipeline=self.pipeline) + Args: + params: The parameters to set for the underlying chat pipeline. + **kwargs: An alternative way to pass parameters as keyword arguments. - def with_(self, params: t.Optional[GenerateParams] = None, **kwargs: t.Any) -> Prompt[P, R]: + Returns: + Prompt with the updated params. + """ if params is None: params = GenerateParams(**kwargs) - if self._params is not None: + if self.params is not None: new = self.clone() - new._params = self._params.merge_with(params) + new.params = self.params.merge_with(params) return new self.params = params return self + # We could put these params into the decorator, but it makes it + # less flexible when we want to build gateway interfaces into + # creating a prompt from other code. + + def set_( + self, attempt_recovery: bool | None = None, drop_dialog: bool | None = None, max_rounds: int | None = None + ) -> Prompt[P, R]: + """ + Helper to allow updates to the parsing configuration. + + Args: + attempt_recovery: Whether the prompt should attempt to recover from errors in output parsing. + drop_dialog: When attempting recovery, whether to drop intermediate dialog while parsing was being resolved. + max_rounds: The maximum number of rounds the prompt should try to reparse outputs. + + Returns: + The current instance of the chat. 
+ """ + self.attempt_recovery = attempt_recovery or self.attempt_recovery + self.drop_dialog = drop_dialog or self.drop_dialog + self.max_rounds = max_rounds or self.max_rounds + return self + + def watch(self, *callbacks: WatchChatCallback) -> Prompt[P, R]: + """ + Registers a callback to monitor any chats produced for this prompt + + Args: + *callbacks: The callback functions to be executed. + + ``` + async def log(chats: list[Chat]) -> None: + ... + + @rg.prompt().watch(log) + async def summarize(text: str) -> str: + ... + + summarize(...) + ``` + + Returns: + The current instance of the chat. + """ + for callback in callbacks: + if callback not in self.watch_callbacks: + self.watch_callbacks.append(callback) + return self + def render(self, *args: P.args, **kwargs: P.kwargs) -> str: - pass + """ + Pass the arguments to the jinja2 template and render the full prompt. + """ + env = Environment( + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=StrictUndefined, + ) + jinja_template = env.from_string(self.template) + + signature = inspect.signature(self.func) + bound_args = signature.bind(*args, **kwargs) + bound_args.apply_defaults() + + for input_ in self.inputs: + bound_args.arguments[to_snake(input_.tag)] = input_.to_xml(bound_args.arguments[input_.name]) + + return jinja_template.render(**bound_args.arguments) + + def process(self, chat: Chat) -> R: + """ + Attempt to parse the output from a chat into the expected return type. + """ + return self.output.from_chat(chat) # type: ignore + + async def run(self, *args: P.args, pipeline: ChatPipeline | None = None, **kwargs: P.kwargs) -> R: + """ + Use the prompt to run the function with the provided arguments and return the output. + + Args: + *args: The positional arguments for the prompt function. + pipeline: An optional pipeline to use for the prompt. + **kwargs: The keyword arguments for the prompt function. + + Returns: + The output of the prompt function. + """ + return (await self.run_many(1, *args, pipeline=pipeline, **kwargs))[0] + + async def run_many( + self, count: int, *args: P.args, pipeline: ChatPipeline | None = None, **kwargs: P.kwargs + ) -> list[R]: + """ + Use the prompt to run the function multiple times with the provided arguments and return the output. + + Args: + count: The number of times to run the prompt. + *args: The positional arguments for the prompt function. + pipeline: An optional pipeline to use for the prompt. + **kwargs: The keyword arguments for the prompt function. + + Returns: + The outputs of the prompt function. 
+ """ + pipeline = pipeline or self.pipeline + if pipeline is None: + raise RuntimeError( + "Prompt cannot be executed as a standalone function without being assigned a pipeline or generator" + ) + + content = self.render(*args, **kwargs) + pipeline = ( + pipeline.fork(content) + .until( + self._until_parsed, + attempt_recovery=self.attempt_recovery, + drop_dialog=self.drop_dialog, + max_rounds=self.max_rounds, + ) + .with_(self.params) + ) + chats = await pipeline.run_many(count) - async def run(self, *args: P.args, **kwargs: P.kwargs) -> R: - pass + coros = [watch(chats) for watch in set(self.watch_callbacks + pipeline.watch_callbacks)] + await asyncio.gather(*coros) - async def run_many(self, *args: P.args, **kwargs: P.kwargs) -> list[R]: - pass + return [self.process(chat) for chat in chats] __call__ = run +# Decorator + + +@t.overload def prompt( - *, pipeline: ChatPipeline | None = None, generator: Generator | None = None, generator_id: str | None = None + func: None = None, + /, + *, + pipeline: ChatPipeline | None = None, + generator: Generator | None = None, + generator_id: str | None = None, ) -> t.Callable[[t.Callable[P, t.Coroutine[None, None, R]]], Prompt[P, R]]: - if sum(arg is not None for arg in (pipeline, generator, generator_id)) > 1: - raise ValueError("Only one of pipeline, generator, or generator_id can be provided") + ... - def decorator(func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]: - return Prompt[P, R](_func=func, _generator_id=generator_id, _pipeline=pipeline, _generator=generator) - return decorator +@t.overload +def prompt( + func: t.Callable[P, t.Coroutine[None, None, R]], + /, + *, + pipeline: ChatPipeline | None = None, + generator: Generator | None = None, + generator_id: str | None = None, +) -> Prompt[P, R]: + ... + +def prompt( + func: t.Callable[P, t.Coroutine[None, None, R]] | None = None, + /, + *, + pipeline: ChatPipeline | None = None, + generator: Generator | None = None, + generator_id: str | None = None, +) -> t.Callable[[t.Callable[P, t.Coroutine[None, None, R]]], Prompt[P, R]] | Prompt[P, R]: + """ + Convert a hollow function into a Prompt, which can be called directly or passed a + chat pipeline to execute the function and parse the outputs. + + ``` + from dataclasses import dataclass + import rigging as rg + + @dataclass + class ExplainedJoke: + chat: rg.Chat + setup: str + punchline: str + explanation: str + + @rg.prompt(generator_id="gpt-3.5-turbo") + async def write_joke(topic: str) -> ExplainedJoke: + \"""Write a joke.\""" + ... + + await write_joke("programming") + + Note: + A docstring is not required, but this can be used to provide guidance to the model, or + even handle any number of input transormations. Any input parameter which is not + handled inside the docstring will be automatically added and formatted internally. + + Note: + Output parameters can be basic types, dataclasses, rigging models, lists, or tuples. + Internal inspection will attempt to ensure your output types are valid, but there is + no guarantee of complete coverage/safety. It's recommended to check + [rigging.prompt.Prompt.template][] to inspect the generated jinja2 template. + + Note: + If you annotate the return value of the function as a [rigging.chat.Chat][] object, + then no output parsing will take place and you can parse objects out manually. + + You can also use Chat in any number of type annotation inside tuples or dataclasses. + All instances will be filled with the final chat object transparently. 
+ + Note: + All input parameters and output types can be annotated with the [rigging.prompt.Ctx][] annotation + to provide additional context for the prompt. This can be used to override the xml tag, provide + a prefix string, or example content which will be placed inside output xml tags. + + In the case of output parameters, especially in tuples, you might have xml tag collisions + between the same basic types. Manually annotating xml tags with [rigging.prompt.Ctx][] is + recommended. + + Args: + func: The function to convert into a prompt. + pipeline: An optional pipeline to use for the prompt. + generator: An optional generator to use for the prompt. + generator_id: An optional generator id to use for the prompt. + + Returns: + A prompt instance or a function that can be used to create a prompt. + """ + if sum(arg is not None for arg in (pipeline, generator, generator_id)) > 1: + raise ValueError("Only one of pipeline, generator, or generator_id can be provided") + + def make_prompt(func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]: + return Prompt[P, R]( + func=func, + _generator_id=generator_id, + _pipeline=pipeline, + _generator=generator, + ) -@prompt() -async def testing() -> None: - pass + if func is not None: + return make_prompt(func) + return make_prompt diff --git a/rigging/util.py b/rigging/util.py new file mode 100644 index 0000000..9137788 --- /dev/null +++ b/rigging/util.py @@ -0,0 +1,32 @@ +import re +import typing as t + +from pydantic import alias_generators +from typing_extensions import ParamSpec + +P = ParamSpec("P") +R = t.TypeVar("R") + + +def escape_xml(xml_string: str) -> str: + prepared = re.sub(r"&(?!(?:amp|lt|gt|apos|quot);)", "&", xml_string) + + return prepared + + +def unescape_xml(xml_string: str) -> str: + unescaped = re.sub(r"&", "&", xml_string) + unescaped = re.sub(r"<", "<", unescaped) + unescaped = re.sub(r">", ">", unescaped) + unescaped = re.sub(r"'", "'", unescaped) + unescaped = re.sub(r""", '"', unescaped) + + return unescaped + + +def to_snake(text: str) -> str: + return alias_generators.to_snake(text) + + +def to_xml_tag(text: str) -> str: + return to_snake(text).replace("_", "-").strip("-") diff --git a/tests/test_prompt.py b/tests/test_prompt.py new file mode 100644 index 0000000..cc12e64 --- /dev/null +++ b/tests/test_prompt.py @@ -0,0 +1,239 @@ +from dataclasses import dataclass +from textwrap import dedent +from typing import Annotated + +import pytest + +import rigging as rg +from rigging.chat import Chat + +# mypy: disable-error-code=empty-body + + +def test_prompt_render_docstring_parse() -> None: + @rg.prompt + async def foo(name: str) -> str: + """Say hello.""" + ... + + assert foo.docstring == "Say hello." + + @rg.prompt + async def bar(name: str) -> str: + """ + Say hello.""" + ... + + assert bar.docstring == "Say hello." + + @rg.prompt + async def baz(name: str) -> str: + """ + Say \ + hello. + + """ + ... + + assert baz.docstring == "Say hello." + + +def test_basic_prompt_render() -> None: + @rg.prompt + async def hello(name: str) -> str: + """Say hello.""" + ... + + rendered = hello.render("Alice") + assert rendered == dedent( + """\ + Say hello. + + Alice + + Produce the following output: + + + """ + ) + + +def test_prompt_render_with_docstring_variables() -> None: + @rg.prompt + async def greet(name: str, greeting: str = "Hello") -> str: + """Say '{{ greeting }}' to {{ name }}.""" + ... + + rendered = greet.render("Bob") + assert rendered == dedent( + """\ + Say 'Hello' to Bob. 
+ + Produce the following output: + + + """ + ) + + +def test_prompt_render_with_model_output() -> None: + class Person(rg.Model): + name: str = rg.element() + age: int = rg.element() + + @rg.prompt + async def create_person(name: str, age: int) -> Person: + """Create a person.""" + ... + + rendered = create_person.render("Alice", 30) + assert rendered == dedent( + """\ + Create a person. + + Alice + + 30 + + Produce the following output: + + + + + + """ + ) + + +def test_prompt_render_with_list_output() -> None: + @rg.prompt + async def generate_numbers(count: int) -> list[int]: + """Generate a list of numbers.""" + ... + + rendered = generate_numbers.render(5) + assert rendered == dedent( + """\ + Generate a list of numbers. + + 5 + + Produce the following output for each item: + + + """ + ) + + +def test_prompt_render_with_tuple_output() -> None: + @rg.prompt + async def create_user(username: str) -> tuple[str, int]: + """Create a new user.""" + ... + + rendered = create_user.render("johndoe") + assert rendered == dedent( + """\ + Create a new user. + + johndoe + + Produce the following outputs: + + + + + """ + ) + + +def test_prompt_render_with_tuple_output_ctx() -> None: + @rg.prompt + async def create_user(username: str) -> tuple[Annotated[str, rg.Ctx(tag="id")], int]: + """Create a new user.""" + ... + + rendered = create_user.render("johndoe") + assert rendered == dedent( + """\ + Create a new user. + + johndoe + + Produce the following outputs: + + + + + """ + ) + + +def test_prompt_render_with_dataclass_output() -> None: + @dataclass + class User: + username: str + email: str + age: int + + @rg.prompt + async def register_user(username: str, email: str, age: int) -> User: + """Register a new user: {{ username}}.""" + ... + + rendered = register_user.render("johndoe", "johndoe@example.com", 25) + assert rendered == dedent( + """\ + Register a new user: johndoe. + + johndoe@example.com + + 25 + + Produce the following outputs: + + + + + + + """ + ) + + +def test_prompt_render_with_chat_return() -> None: + @rg.prompt + async def foo(input_: str) -> Chat: + """Do something.""" + ... + + rendered = foo.render("bar") + assert rendered == dedent( + """\ + Do something. + + bar + """ + ) + + +def test_prompt_parse_fail_nested_input() -> None: + async def foo(arg: list[list[str]]) -> Chat: + ... + + with pytest.raises(TypeError): + rg.prompt(foo) + + async def bar(arg: tuple[int, str, tuple[str]]) -> Chat: + ... + + with pytest.raises(TypeError): + rg.prompt(bar) + + +def test_prompt_parse_fail_unique_ouput() -> None: + async def foo(arg: int) -> tuple[str, str]: + ... + + with pytest.raises(TypeError): + rg.prompt(foo)
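Taken together, the additions above support a workflow like the following minimal sketch, assembled from the docstrings and tests in this diff. The model ids are placeholders, and `rg.get_generator(...).chat()` is assumed to be the existing way to build a `ChatPipeline`; the `@rg.prompt` decorator, `rg.Ctx`, `render`, `run_many`, and `ChatPipeline.run_prompt` are all defined in `rigging/prompt.py` and `rigging/chat.py` in this change.

```python
import asyncio
from dataclasses import dataclass
from typing import Annotated

import rigging as rg


@dataclass
class ExplainedJoke:
    chat: rg.Chat  # Chat-typed fields are filled with the final chat object transparently
    setup: str
    punchline: str
    explanation: str


# Binding a generator id at decoration time lets the prompt be awaited directly.
@rg.prompt(generator_id="gpt-3.5-turbo")
async def write_joke(topic: str) -> ExplainedJoke:
    """Write a joke."""
    ...


# Ctx overrides the XML tag for an output, which avoids collisions between basic types in a tuple.
@rg.prompt
async def create_user(username: str) -> tuple[Annotated[str, rg.Ctx(tag="id")], int]:
    """Create a new user."""
    ...


async def main() -> None:
    # Inspect the generated jinja2 template rendered with arguments, without running anything.
    print(write_joke.render("programming"))

    joke = await write_joke("programming")
    print(joke.punchline)

    # Run the same prompt several times...
    jokes = await write_joke.run_many(3, "databases")
    print(len(jokes))

    # ...or route an unbound prompt through an existing pipeline.
    pipeline = rg.get_generator("gpt-4o").chat()  # illustrative model id
    user_id, age = await pipeline.run_prompt(create_user, "johndoe")
    print(user_id, age)


asyncio.run(main())
```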