diff --git a/docs/api/prompt.md b/docs/api/prompt.md
new file mode 100644
index 0000000..614c644
--- /dev/null
+++ b/docs/api/prompt.md
@@ -0,0 +1 @@
+::: rigging.prompt
\ No newline at end of file
diff --git a/docs/topics/prompt-functions.md b/docs/topics/prompt-functions.md
new file mode 100644
index 0000000..e69de29
diff --git a/mkdocs.yml b/mkdocs.yml
index 57ae1ee..c9a2bdc 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -11,6 +11,7 @@ nav:
- Models: topics/models.md
- Generators: topics/generators.md
- Chats and Messages: topics/chats-and-messages.md
+ - Prompt Functions: topics/prompt-functions.md
- Completions: topics/completions.md
- Callbacks and Mapping: topics/callbacks-and-mapping.md
- Async and Batching: topics/async-and-batching.md
@@ -26,6 +27,7 @@ nav:
- rigging.generator: api/generator.md
- rigging.model: api/model.md
- rigging.message: api/message.md
+ - rigging.prompt: api/prompt.md
- rigging.tool: api/tool.md
- rigging.data: api/data.md
- rigging.parsing: api/parsing.md
diff --git a/poetry.lock b/poetry.lock
index ca01952..b3dee5d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -634,43 +634,43 @@ toml = ["tomli"]
[[package]]
name = "cryptography"
-version = "42.0.7"
+version = "42.0.8"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = true
python-versions = ">=3.7"
files = [
- {file = "cryptography-42.0.7-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:a987f840718078212fdf4504d0fd4c6effe34a7e4740378e59d47696e8dfb477"},
- {file = "cryptography-42.0.7-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd13b5e9b543532453de08bcdc3cc7cebec6f9883e886fd20a92f26940fd3e7a"},
- {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79165431551042cc9d1d90e6145d5d0d3ab0f2d66326c201d9b0e7f5bf43604"},
- {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a47787a5e3649008a1102d3df55424e86606c9bae6fb77ac59afe06d234605f8"},
- {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:02c0eee2d7133bdbbc5e24441258d5d2244beb31da5ed19fbb80315f4bbbff55"},
- {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:5e44507bf8d14b36b8389b226665d597bc0f18ea035d75b4e53c7b1ea84583cc"},
- {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:7f8b25fa616d8b846aef64b15c606bb0828dbc35faf90566eb139aa9cff67af2"},
- {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:93a3209f6bb2b33e725ed08ee0991b92976dfdcf4e8b38646540674fc7508e13"},
- {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e6b8f1881dac458c34778d0a424ae5769de30544fc678eac51c1c8bb2183e9da"},
- {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3de9a45d3b2b7d8088c3fbf1ed4395dfeff79d07842217b38df14ef09ce1d8d7"},
- {file = "cryptography-42.0.7-cp37-abi3-win32.whl", hash = "sha256:789caea816c6704f63f6241a519bfa347f72fbd67ba28d04636b7c6b7da94b0b"},
- {file = "cryptography-42.0.7-cp37-abi3-win_amd64.whl", hash = "sha256:8cb8ce7c3347fcf9446f201dc30e2d5a3c898d009126010cbd1f443f28b52678"},
- {file = "cryptography-42.0.7-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:a3a5ac8b56fe37f3125e5b72b61dcde43283e5370827f5233893d461b7360cd4"},
- {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:779245e13b9a6638df14641d029add5dc17edbef6ec915688f3acb9e720a5858"},
- {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d563795db98b4cd57742a78a288cdbdc9daedac29f2239793071fe114f13785"},
- {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:31adb7d06fe4383226c3e963471f6837742889b3c4caa55aac20ad951bc8ffda"},
- {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:efd0bf5205240182e0f13bcaea41be4fdf5c22c5129fc7ced4a0282ac86998c9"},
- {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a9bc127cdc4ecf87a5ea22a2556cab6c7eda2923f84e4f3cc588e8470ce4e42e"},
- {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:3577d029bc3f4827dd5bf8bf7710cac13527b470bbf1820a3f394adb38ed7d5f"},
- {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2e47577f9b18723fa294b0ea9a17d5e53a227867a0a4904a1a076d1646d45ca1"},
- {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1a58839984d9cb34c855197043eaae2c187d930ca6d644612843b4fe8513c886"},
- {file = "cryptography-42.0.7-cp39-abi3-win32.whl", hash = "sha256:e6b79d0adb01aae87e8a44c2b64bc3f3fe59515280e00fb6d57a7267a2583cda"},
- {file = "cryptography-42.0.7-cp39-abi3-win_amd64.whl", hash = "sha256:16268d46086bb8ad5bf0a2b5544d8a9ed87a0e33f5e77dd3c3301e63d941a83b"},
- {file = "cryptography-42.0.7-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2954fccea107026512b15afb4aa664a5640cd0af630e2ee3962f2602693f0c82"},
- {file = "cryptography-42.0.7-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:362e7197754c231797ec45ee081f3088a27a47c6c01eff2ac83f60f85a50fe60"},
- {file = "cryptography-42.0.7-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4f698edacf9c9e0371112792558d2f705b5645076cc0aaae02f816a0171770fd"},
- {file = "cryptography-42.0.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5482e789294854c28237bba77c4c83be698be740e31a3ae5e879ee5444166582"},
- {file = "cryptography-42.0.7-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e9b2a6309f14c0497f348d08a065d52f3020656f675819fc405fb63bbcd26562"},
- {file = "cryptography-42.0.7-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d8e3098721b84392ee45af2dd554c947c32cc52f862b6a3ae982dbb90f577f14"},
- {file = "cryptography-42.0.7-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c65f96dad14f8528a447414125e1fc8feb2ad5a272b8f68477abbcc1ea7d94b9"},
- {file = "cryptography-42.0.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:36017400817987670037fbb0324d71489b6ead6231c9604f8fc1f7d008087c68"},
- {file = "cryptography-42.0.7.tar.gz", hash = "sha256:ecbfbc00bf55888edda9868a4cf927205de8499e7fabe6c050322298382953f2"},
+ {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:81d8a521705787afe7a18d5bfb47ea9d9cc068206270aad0b96a725022e18d2e"},
+ {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:961e61cefdcb06e0c6d7e3a1b22ebe8b996eb2bf50614e89384be54c48c6b63d"},
+ {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3ec3672626e1b9e55afd0df6d774ff0e953452886e06e0f1eb7eb0c832e8902"},
+ {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e599b53fd95357d92304510fb7bda8523ed1f79ca98dce2f43c115950aa78801"},
+ {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5226d5d21ab681f432a9c1cf8b658c0cb02533eece706b155e5fbd8a0cdd3949"},
+ {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6b7c4f03ce01afd3b76cf69a5455caa9cfa3de8c8f493e0d3ab7d20611c8dae9"},
+ {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:2346b911eb349ab547076f47f2e035fc8ff2c02380a7cbbf8d87114fa0f1c583"},
+ {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:ad803773e9df0b92e0a817d22fd8a3675493f690b96130a5e24f1b8fabbea9c7"},
+ {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2f66d9cd9147ee495a8374a45ca445819f8929a3efcd2e3df6428e46c3cbb10b"},
+ {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d45b940883a03e19e944456a558b67a41160e367a719833c53de6911cabba2b7"},
+ {file = "cryptography-42.0.8-cp37-abi3-win32.whl", hash = "sha256:a0c5b2b0585b6af82d7e385f55a8bc568abff8923af147ee3c07bd8b42cda8b2"},
+ {file = "cryptography-42.0.8-cp37-abi3-win_amd64.whl", hash = "sha256:57080dee41209e556a9a4ce60d229244f7a66ef52750f813bfbe18959770cfba"},
+ {file = "cryptography-42.0.8-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:dea567d1b0e8bc5764b9443858b673b734100c2871dc93163f58c46a97a83d28"},
+ {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4783183f7cb757b73b2ae9aed6599b96338eb957233c58ca8f49a49cc32fd5e"},
+ {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0608251135d0e03111152e41f0cc2392d1e74e35703960d4190b2e0f4ca9c70"},
+ {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dc0fdf6787f37b1c6b08e6dfc892d9d068b5bdb671198c72072828b80bd5fe4c"},
+ {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9c0c1716c8447ee7dbf08d6db2e5c41c688544c61074b54fc4564196f55c25a7"},
+ {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fff12c88a672ab9c9c1cf7b0c80e3ad9e2ebd9d828d955c126be4fd3e5578c9e"},
+ {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cafb92b2bc622cd1aa6a1dce4b93307792633f4c5fe1f46c6b97cf67073ec961"},
+ {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:31f721658a29331f895a5a54e7e82075554ccfb8b163a18719d342f5ffe5ecb1"},
+ {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b297f90c5723d04bcc8265fc2a0f86d4ea2e0f7ab4b6994459548d3a6b992a14"},
+ {file = "cryptography-42.0.8-cp39-abi3-win32.whl", hash = "sha256:2f88d197e66c65be5e42cd72e5c18afbfae3f741742070e3019ac8f4ac57262c"},
+ {file = "cryptography-42.0.8-cp39-abi3-win_amd64.whl", hash = "sha256:fa76fbb7596cc5839320000cdd5d0955313696d9511debab7ee7278fc8b5c84a"},
+ {file = "cryptography-42.0.8-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ba4f0a211697362e89ad822e667d8d340b4d8d55fae72cdd619389fb5912eefe"},
+ {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:81884c4d096c272f00aeb1f11cf62ccd39763581645b0812e99a91505fa48e0c"},
+ {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c9bb2ae11bfbab395bdd072985abde58ea9860ed84e59dbc0463a5d0159f5b71"},
+ {file = "cryptography-42.0.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7016f837e15b0a1c119d27ecd89b3515f01f90a8615ed5e9427e30d9cdbfed3d"},
+ {file = "cryptography-42.0.8-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5a94eccb2a81a309806027e1670a358b99b8fe8bfe9f8d329f27d72c094dde8c"},
+ {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dec9b018df185f08483f294cae6ccac29e7a6e0678996587363dc352dc65c842"},
+ {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:343728aac38decfdeecf55ecab3264b015be68fc2816ca800db649607aeee648"},
+ {file = "cryptography-42.0.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:013629ae70b40af70c9a7a5db40abe5d9054e6f4380e50ce769947b73bf3caad"},
+ {file = "cryptography-42.0.8.tar.gz", hash = "sha256:8d09d05439ce7baa8e9e95b07ec5b6c886f548deb7e0f69ef25f64b3bce842f2"},
]
[package.dependencies]
@@ -1217,13 +1217,13 @@ socks = ["socksio (==1.*)"]
[[package]]
name = "huggingface-hub"
-version = "0.23.2"
+version = "0.23.3"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
- {file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
+ {file = "huggingface_hub-0.23.3-py3-none-any.whl", hash = "sha256:22222c41223f1b7c209ae5511d2d82907325a0e3cdbce5f66949d43c598ff3bc"},
+ {file = "huggingface_hub-0.23.3.tar.gz", hash = "sha256:1a1118a0b3dea3bab6c325d71be16f5ffe441d32f3ac7c348d6875911b694b5b"},
]
[package.dependencies]
@@ -2507,13 +2507,13 @@ files = [
[[package]]
name = "openai"
-version = "1.31.0"
+version = "1.31.1"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.7.1"
files = [
- {file = "openai-1.31.0-py3-none-any.whl", hash = "sha256:82044ee3122113f2a468a1f308a8882324d09556ba5348687c535d3655ee331c"},
- {file = "openai-1.31.0.tar.gz", hash = "sha256:54ae0625b005d6a3b895db2b8438dae1059cffff0cd262a26e9015c13a29ab06"},
+ {file = "openai-1.31.1-py3-none-any.whl", hash = "sha256:a746cf070798a4048cfea00b0fc7cb9760ee7ead5a08c48115b914d1afbd1b53"},
+ {file = "openai-1.31.1.tar.gz", hash = "sha256:a15266827de20f407d4bf9837030b168074b5b29acd54f10bb38d5f53e95f083"},
]
[package.dependencies]
@@ -5093,6 +5093,17 @@ files = [
numpy = "*"
torch = "2.3.0"
+[[package]]
+name = "xmltodict"
+version = "0.13.0"
+description = "Makes working with XML feel like you are working with JSON"
+optional = false
+python-versions = ">=3.4"
+files = [
+ {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
+ {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+]
+
[[package]]
name = "yarl"
version = "1.9.4"
@@ -5218,4 +5229,4 @@ examples = ["asyncssh", "click", "httpx", "types-requests"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
-content-hash = "bdbac14ad41e77ecb3d7a47f5a8892fad141aba1dc7cc14d9b5cd80a25527c5d"
+content-hash = "a6b1bec0675b4f10e46a6cb31c76dc33e12c3f5c3bd9c6f5677083886a6274de"
diff --git a/pyproject.toml b/pyproject.toml
index 28d1219..a0dcd2a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ asyncssh = { version = "^2.14.2", optional = true }
types-requests = { version = "^2.32.0.20240523", optional = true }
click = { version = "^8.1.7", optional = true }
httpx = { version = "^0.27.0", optional = true }
+xmltodict = "^0.13.0"
[tool.poetry.extras]
examples = ["asyncssh", "types-requests", "click", "httpx"]
diff --git a/rigging/__init__.py b/rigging/__init__.py
index 84250ab..7c43310 100644
--- a/rigging/__init__.py
+++ b/rigging/__init__.py
@@ -22,6 +22,7 @@
)
from rigging.message import Message, MessageDict, Messages
from rigging.model import Model, attr, element, make_primitive, wrapped
+from rigging.prompt import Ctx, Prompt, prompt
from rigging.tool import Tool
__version__ = "1.3.0"
@@ -56,6 +57,9 @@
"chats_to_elastic_data",
"flatten_chats",
"unflatten_chats",
+ "prompt",
+ "Prompt",
+ "Ctx",
]
from loguru import logger
diff --git a/rigging/chat.py b/rigging/chat.py
index 1af5ba1..7655b64 100644
--- a/rigging/chat.py
+++ b/rigging/chat.py
@@ -28,6 +28,8 @@
from elasticsearch import AsyncElasticsearch
from rigging.data import ElasticOpType
+ from rigging.prompt import Prompt
+ from rigging.util import P, R
DEFAULT_MAX_ROUNDS = 5
"""Maximum number of internal callback rounds to attempt during generation before giving up."""
@@ -203,7 +205,16 @@ def continue_(self, messages: t.Sequence[Message] | t.Sequence[MessageDict] | Me
return self.fork(messages, include_all=True)
def clone(self, *, only_messages: bool = False) -> Chat:
- """Creates a deep copy of the chat."""
+ """
+ Creates a deep copy of the chat.
+
+ Args:
+ only_messages: If True, only the messages will be cloned.
+ If False (default), the entire chat object will be cloned.
+
+ Returns:
+ A new instance of Chat.
+ """
new = Chat(
[m.model_copy() for m in self.messages],
[m.model_copy() for m in self.generated],
@@ -211,6 +222,11 @@ def clone(self, *, only_messages: bool = False) -> Chat:
)
if not only_messages:
new.metadata = deepcopy(self.metadata)
+ new.params = self.params.model_copy() if self.params is not None else None
+ new.stop_reason = self.stop_reason
+ new.usage = self.usage.model_copy() if self.usage is not None else None
+ new.extra = deepcopy(self.extra)
+ new.failed = self.failed
return new
def apply(self, **kwargs: str) -> Chat:
@@ -371,7 +387,7 @@ async def to_elastic(
"""
from rigging.data import chats_to_elastic
- return chats_to_elastic(self, index, client, op_type=op_type, create_index=create_index, **kwargs)
+ return await chats_to_elastic(self, index, client, op_type=op_type, create_index=create_index, **kwargs)
# Callbacks
@@ -564,7 +580,12 @@ def clone(self, *, only_messages: bool = False) -> ChatPipeline:
Returns:
A new instance of `ChatPipeline` that is a clone of the current instance.
"""
- new = ChatPipeline(self.generator, [], params=self.params, watch_callbacks=self.watch_callbacks)
+ new = ChatPipeline(
+ self.generator,
+ [],
+ params=self.params.model_copy() if self.params is not None else None,
+ watch_callbacks=self.watch_callbacks,
+ )
new.chat = self.chat.clone()
if not only_messages:
new.until_callbacks = self.until_callbacks.copy()
@@ -573,8 +594,8 @@ def clone(self, *, only_messages: bool = False) -> ChatPipeline:
new.inject_tool_prompt = self.inject_tool_prompt
new.force_tool = self.force_tool
new.metadata = deepcopy(self.metadata)
- new.then_chat_callbacks = self.then_chat_callbacks.copy()
- new.map_chat_callbacks = self.map_chat_callbacks.copy()
+ new.then_callbacks = self.then_callbacks.copy()
+ new.map_callbacks = self.map_callbacks.copy()
return new
def meta(self, **kwargs: t.Any) -> ChatPipeline:
@@ -1157,3 +1178,72 @@ async def run_batch(
chats = [s.chat for s in states if s.chat is not None]
return await self._post_run(chats)
+
+ # Generator iteration
+
+ async def arun_over(
+ self,
+ *generators: Generator | str,
+ include_original: bool = True,
+ skip_failed: bool = False,
+ include_failed: bool = False,
+ ) -> ChatList:
+ """
+ Executes the generation process across multiple generators.
+
+        For each generator, this pipeline is cloned and the generator is replaced
+        before the run call. All callbacks and parameters are preserved.
+
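+        Example (generator ids shown are illustrative):
+
+        ```
+        chats = await pipeline.arun_over("gpt-3.5-turbo", "gpt-4o", skip_failed=True)
+        ```
+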
+        Args:
+            *generators: A sequence of generators to be used for the generation process.
+            include_original: Whether to include the original generator in the list of runs.
+            skip_failed: Enable to ignore any max rounds errors and return only successful chats.
+            include_failed: Enable to ignore max rounds errors and return both
+                successful and failed chats.
+
+        Returns:
+            A list of generated Chats.
+ """
+ if skip_failed and include_failed:
+ raise ValueError("Cannot use both skip_failed and include_failed")
+
+ _generators: list[Generator] = [g if isinstance(g, Generator) else get_generator(g) for g in generators]
+ if include_original:
+ _generators.append(self.generator)
+
+ coros: list[t.Coroutine[t.Any, t.Any, Chat]] = []
+ for generator in _generators:
+ sub = self.clone()
+ sub.generator = generator
+ coros.append(sub.run(allow_failed=skip_failed or include_failed))
+
+ chats = await asyncio.gather(*coros)
+
+ if skip_failed:
+ chats = [c for c in chats if not c.failed]
+
+ return ChatList(chats)
+
+ # Prompt functions
+
+ def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]:
+ """
+ Decorator to convert a function into a prompt bound to this pipeline.
+
+ See [rigging.prompt.prompt][] for more information.
+
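+        Example (hypothetical function shown for illustration):
+
+        ```
+        @pipeline.prompt
+        async def summarize(text: str) -> str:
+            \"""Summarize the text in one sentence.\"""
+            ...
+
+        summary = await summarize("...")
+        ```
+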
+ Args:
+ func: The function to be converted into a prompt.
+
+ Returns:
+ The prompt.
+ """
+ from rigging.prompt import prompt
+
+ return prompt(func, pipeline=self)
+
+ async def run_prompt(self, prompt: Prompt[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
+ return await prompt.run(*args, pipeline=self, **kwargs)
+
+ async def run_prompt_many(self, prompt: Prompt[P, R], count: int, *args: P.args, **kwargs: P.kwargs) -> list[R]:
+ return await prompt.run_many(count, *args, pipeline=self, **kwargs)
diff --git a/rigging/completion.py b/rigging/completion.py
index 5d6fa8c..1028846 100644
--- a/rigging/completion.py
+++ b/rigging/completion.py
@@ -117,7 +117,6 @@ def restart(self, *, generator: t.Optional[Generator] = None, include_all: bool
generator: The generator to use for the restarted completion. Otherwise
the generator from the original CompletionPipeline will be used.
include_all: Whether to include the generation before the next round.
-
Returns:
The restarted completion.
@@ -153,6 +152,11 @@ def clone(self, *, only_messages: bool = False) -> Completion:
new = Completion(self.text, self.generated, self.generator)
if not only_messages:
new.metadata = deepcopy(self.metadata)
+ new.stop_reason = self.stop_reason
+ new.usage = self.usage.model_copy() if self.usage is not None else self.usage
+ new.extra = deepcopy(self.extra)
+ new.params = self.params.model_copy() if self.params is not None else self.params
+ new.failed = self.failed
return new
def meta(self, **kwargs: t.Any) -> Completion:
@@ -388,12 +392,18 @@ def clone(self, *, only_text: bool = False) -> CompletionPipeline:
Returns:
A new instance of `CompletionPipeline` that is a clone of the current instance.
"""
- new = CompletionPipeline(self.generator, self.text, params=self.params, watch_callbacks=self.watch_callbacks)
+ new = CompletionPipeline(
+ self.generator,
+ self.text,
+ params=self.params.model_copy() if self.params is not None else None,
+ watch_callbacks=self.watch_callbacks,
+ )
if not only_text:
new.until_callbacks = self.until_callbacks.copy()
new.until_types = self.until_types.copy()
new.metadata = deepcopy(self.metadata)
new.then_callbacks = self.then_callbacks.copy()
+ new.map_callbacks = self.map_callbacks.copy()
return new
def meta(self, **kwargs: t.Any) -> CompletionPipeline:
diff --git a/rigging/generator/base.py b/rigging/generator/base.py
index 19f2a9e..f8fe50a 100644
--- a/rigging/generator/base.py
+++ b/rigging/generator/base.py
@@ -13,6 +13,8 @@
if t.TYPE_CHECKING:
from rigging.chat import ChatPipeline, WatchChatCallback
from rigging.completion import CompletionPipeline, WatchCompletionCallback
+ from rigging.prompt import Prompt
+ from rigging.util import P, R
WatchCallbacks = t.Union[WatchChatCallback, WatchCompletionCallback]
@@ -28,13 +30,11 @@ def __call__(self) -> type[Generator]:
g_providers: dict[str, type[Generator] | LazyGenerator] = {}
-# TODO: Ideally we flex this to support arbitrary
-# generator params, but we'll limit things
-# for now until we understand the use cases
-#
# TODO: We also would like to support N-style
# parallel generation eventually -> need to
# update our interfaces to support that
+
+
class GenerateParams(BaseModel):
"""
Parameters for generating text using a language model.
@@ -252,7 +252,7 @@ def watch(self, *callbacks: WatchCallbacks, allow_duplicates: bool = False) -> G
allow_duplicates: Whether to allow (seemingly) duplicate callbacks to be added.
```
- def log(chats: list[Chat]) -> None:
+ async def log(chats: list[Chat]) -> None:
...
pipeline.watch(log).run()
@@ -398,6 +398,22 @@ def complete(self, text: str, params: GenerateParams | None = None) -> Completio
return CompletionPipeline(self, text, params=params, watch_callbacks=completion_watch_callbacks)
+ def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]:
+ """
+ Decorator to convert a function into a prompt bound to this generator.
+
+ See [rigging.prompt.prompt][] for more information.
+
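+        Example (hypothetical function shown for illustration):
+
+        ```
+        @generator.prompt
+        async def classify(text: str) -> bool:
+            \"""Is this text positive in sentiment?\"""
+            ...
+
+        positive = await classify("I love this library")
+        ```
+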
+ Args:
+ func: The function to be converted into a prompt.
+
+ Returns:
+ The prompt.
+ """
+ from rigging.prompt import prompt
+
+ return prompt(func, generator=self)
+
@t.overload
def chat(
diff --git a/rigging/message.py b/rigging/message.py
index a708dab..d0a26a2 100644
--- a/rigging/message.py
+++ b/rigging/message.py
@@ -133,7 +133,7 @@ def _remove_part(self, part: ParsedMessagePart) -> str:
def _add_part(self, part: ParsedMessagePart) -> None:
for existing in self.parts:
- if part.slice_ == existing.slice_ and isinstance(part.model, type(existing.model)):
+ if part.slice_ == existing.slice_ and part.model.xml_tags() == existing.model.xml_tags():
return # We clearly already have this part defined
if max(part.slice_.start, existing.slice_.start) < min(part.slice_.stop, existing.slice_.stop):
raise ValueError("Incoming part overlaps with an existing part")
@@ -368,7 +368,7 @@ def fit(cls, message: t.Union[Message, MessageDict, str]) -> Message:
"""Helper function to convert various common types to a Message object."""
if isinstance(message, str):
return cls(role="user", content=message)
- return cls(**message) if isinstance(message, dict) else message
+ return cls(**message) if isinstance(message, dict) else message.model_copy()
@classmethod
def apply_to_list(cls, messages: t.Sequence[Message], **kwargs: str) -> list[Message]:
diff --git a/rigging/model.py b/rigging/model.py
index 57abfdf..be77d5d 100644
--- a/rigging/model.py
+++ b/rigging/model.py
@@ -8,6 +8,7 @@
import typing as t
from xml.etree import ElementTree as ET
+import xmltodict # type: ignore
from pydantic import (
BeforeValidator,
SerializationInfo,
@@ -16,7 +17,6 @@
field_serializer,
field_validator,
)
-from pydantic.alias_generators import to_snake
from pydantic_xml import BaseXmlModel
from pydantic_xml import attr as attr
from pydantic_xml import element as element
@@ -24,6 +24,7 @@
from pydantic_xml.typedefs import EntityLocation, NsMap
from rigging.error import MissingModelError
+from rigging.util import escape_xml, to_xml_tag, unescape_xml
if t.TYPE_CHECKING:
from pydantic_xml.element import SearchMode # type: ignore [attr-defined]
@@ -42,25 +43,6 @@
BASIC_TYPES = [int, str, float, bool]
-def escape_xml(xml_string: str) -> str:
-    prepared = re.sub(r"&(?!(?:amp|lt|gt|apos|quot);)", "&amp;", xml_string)
-
- return prepared
-
-
-def unescape_xml(xml_string: str) -> str:
- # We only expect to use this in our "simple"
- # models, but I'd like a better long-term solution
-
-    unescaped = re.sub(r"&amp;", "&", xml_string)
-    unescaped = re.sub(r"&lt;", "<", unescaped)
-    unescaped = re.sub(r"&gt;", ">", unescaped)
-    unescaped = re.sub(r"&apos;", "'", unescaped)
-    unescaped = re.sub(r"&quot;", '"', unescaped)
-
- return unescaped
-
-
class XmlTagDescriptor:
def __get__(self, _: t.Any, owner: t.Any) -> str:
mro_iter = iter(owner.mro())
@@ -79,7 +61,7 @@ def __get__(self, _: t.Any, owner: t.Any) -> str:
if "[" in cls.__name__:
return t.cast(str, parent.__xml_tag__)
- return to_snake(cls.__name__).replace("_", "-")
+ return to_xml_tag(cls.__name__)
class Model(BaseXmlModel):
@@ -118,6 +100,8 @@ def to_pretty_xml(self) -> str:
pretty_encoded_xml = ET.tostring(tree, short_empty_elements=False).decode()
if self.__class__.is_simple():
+ # We only expect to use this in our "simple"
+ # models, but I'd like a better long-term solution
return unescape_xml(pretty_encoded_xml)
else:
return pretty_encoded_xml
@@ -167,12 +151,18 @@ def xml_example(cls) -> str:
Models should typically override this method to provide a more complex example.
- By default, this method just returns the XML tags for the class.
+ By default, this method returns a hollow XML scaffold one layer deep.
Returns:
A string containing the XML representation of the class.
"""
- return cls.xml_tags()
+ schema = cls.model_json_schema()
+ properties = schema["properties"]
+ structure = {cls.__xml_tag__: {field: None for field in properties}}
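+        # For example (hypothetical model), a Person model with `name` and `age` fields
+        # becomes {"person": {"name": None, "age": None}}, which xmltodict renders as an
+        # empty <person> scaffold with self-closing <name/> and <age/> children.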
+ xml_string = xmltodict.unparse(
+ structure, pretty=True, full_document=False, indent=" ", short_empty_elements=True
+ )
+ return t.cast(str, xml_string) # Bad type hints in xmltodict
@classmethod
def ensure_valid(cls) -> None:
@@ -299,7 +289,7 @@ class Primitive(Model, t.Generic[PrimitiveT]):
def make_primitive(
name: str,
- type_: PrimitiveT = str,
+ type_: type[PrimitiveT] = str, # type: ignore [assignment]
*,
tag: str | None = None,
doc: str | None = None,
@@ -328,7 +318,7 @@ def _validate(value: str) -> str:
return create_model(
name,
- __base__=Primitive[type_],
+ __base__=Primitive[type_], # type: ignore
__doc__=doc,
__cls_kwargs__={"tag": tag},
content=(t.Annotated[type_, BeforeValidator(lambda x: x.strip() if isinstance(x, str) else x)], ...),
diff --git a/rigging/prompt.py b/rigging/prompt.py
index af8f013..3fa4786 100644
--- a/rigging/prompt.py
+++ b/rigging/prompt.py
@@ -1,182 +1,788 @@
from __future__ import annotations
+import asyncio
+import dataclasses
import inspect
+import re
import typing as t
-from jinja2 import Environment, meta
-from pydantic import BaseModel, computed_field, model_validator
-from typing_extensions import ParamSpec
+from jinja2 import Environment, StrictUndefined, meta
+from pydantic import ValidationError
+from rigging.chat import Chat
from rigging.generator.base import GenerateParams, Generator, get_generator
-from rigging.model import Model
+from rigging.message import Message
+from rigging.model import Model, SystemErrorModel, ValidationErrorModel, make_primitive
+from rigging.util import P, R, escape_xml, to_snake, to_xml_tag
if t.TYPE_CHECKING:
- from rigging.chat import ChatPipeline
+ from rigging.chat import ChatPipeline, WatchChatCallback
-DEFAULT_DOC = "You will convert the following inputs to outputs."
+DEFAULT_DOC = "Convert the following inputs to outputs ({func_name})."
+"""Default docstring if none is provided to a prompt function."""
+
+DEFAULT_MAX_ROUNDS = 3
+"""Default maximum number of rounds for a prompt to run until outputs are parsed."""
+
+# Annotation
+
+
+@dataclasses.dataclass
+class Ctx:
+ """
+ Used in type annotations to provide additional context for the prompt construction.
+
+    You can use this annotation on inputs and outputs to prompt functions.
+
+ ```
+ tag_override = Annotated[str, Ctx(tag="custom_tag", ...)]
+ ```
+ """
+
+ tag: str | None = None
+ prefix: str | None = None
+ example: str | None = None
-P = ParamSpec("P")
-R = t.TypeVar("R")
-BASIC_TYPES = [int, float, str, bool, list, dict, set, tuple, type(None)]
# Utilities
-def get_undefined_values(template: str) -> set[str]:
+def unwrap_annotated(annotation: t.Any) -> tuple[t.Any, t.Optional[Ctx]]:
+ if t.get_origin(annotation) is t.Annotated:
+ base_type, *meta = t.get_args(annotation)
+ for m in meta:
+ if isinstance(m, Ctx):
+ return base_type, m
+ return base_type, None
+ return annotation, None
+
+
+def get_undeclared_variables(template: str) -> set[str]:
env = Environment()
parsed_template = env.parse(template)
return meta.find_undeclared_variables(parsed_template)
-def format_parameter(param: inspect.Parameter, value: t.Any) -> str:
- name = param.name
+def make_parameter(
+ annotation: t.Any, *, name: str = "", kind: inspect._ParameterKind = inspect.Parameter.VAR_KEYWORD
+) -> inspect.Parameter:
+ return inspect.Parameter(name=name, kind=kind, annotation=annotation)
+
+
+# Function Inputs
+
+
+@dataclasses.dataclass
+class Input:
+ name: str
+ context: Ctx
+
+ @property
+ def tag(self) -> str:
+ return self.context.tag or to_xml_tag(self.name)
+
+ def _prefix(self, xml: str) -> str:
+ if self.context.prefix:
+ return f"{self.context.prefix}\n{xml}"
+ return xml
+
+ def to_str(self, value: t.Any) -> str:
+ raise NotImplementedError
+
+ def to_xml(self, value: t.Any) -> str:
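+        # Wrap the stringified value in this input's tag, e.g. <topic>programming</topic>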
+ value_str = self.to_str(value)
+ if "\n" in value_str:
+ value_str = f"\n{value_str}\n"
+        return self._prefix(f"<{self.tag}>{escape_xml(value_str)}</{self.tag}>")
+
+
+@dataclasses.dataclass
+class BasicInput(Input):
+ def to_str(self, value: t.Any) -> str:
+ if not isinstance(value, (int, float, str, bool)):
+ raise ValueError(f"Value must be a basic type, got: {type(value)}")
- if isinstance(value, str):
- if "\n" in value:
- value = f"\n{value.strip()}\n"
-        return f"<{name}>{value}</{name}>"
+ return str(value)
- if isinstance(value, (int, float)):
-        return f"<{name}>{value}</{name}>"
- if isinstance(value, bool):
-        return f"<{name}>{'true' if value else 'false'}</{name}>"
+@dataclasses.dataclass
+class ModelInput(Input):
+ def to_str(self, value: t.Any) -> str:
+ if not isinstance(value, Model):
+ raise ValueError(f"Value must be a Model instance, got: {type(value)}")
- if isinstance(value, Model):
return value.to_pretty_xml()
- if isinstance(value, (list, set)):
- type_args = t.get_args(param.annotation)
+ def to_xml(self, value: t.Any) -> str:
+ return self._prefix(self.to_str(value))
+
+
+@dataclasses.dataclass
+class ListInput(Input):
+ interior: Input
+
+ def to_str(self, value: list[t.Any]) -> str:
+ return "\n\n".join(self.interior.to_str(v) for v in value)
+
- xml = f"<{name}>\n"
- for item in value:
- pass
+@dataclasses.dataclass
+class DictInput(Input):
+ interior: Input
- raise ValueError(f"Unsupported parameter: {param}: '{value}'")
+ def to_str(self, value: t.Any) -> str:
+ if not isinstance(value, dict):
+ raise ValueError(f"Value must be a dictionary, got: {type(value)}")
+ if not all(isinstance(k, str) for k in value.keys()):
+ raise ValueError("Dictionary keys must be strings")
+        return "\n".join(f"<{k}>{self.interior.to_str(v)}</{k}>" for k, v in value.items())
-def check_valid_function(func: t.Callable[P, t.Coroutine[None, None, R]]) -> None:
- signature = inspect.signature(func)
- for param in signature.parameters.values():
- error_name = f"{func.__name__}({param})"
+def parse_parameter(param: inspect.Parameter, error_name: str) -> Input:
+ if param.kind not in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ ):
+ raise TypeError(f"Parameters must be positional or keyword {error_name}")
- if param.kind not in (
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.KEYWORD_ONLY,
- ):
- raise TypeError(f"Parameters must be positional or keyword {error_name}")
+ if param.annotation in [None, inspect.Parameter.empty]:
+ raise TypeError(f"All parameters require type annotations {error_name}")
- if param.annotation in [None, inspect.Parameter.empty]:
- raise TypeError(f"All parameters require type annotations {error_name}")
+ annotation, context = unwrap_annotated(param.annotation)
- origin = t.get_origin(param.annotation)
- if origin is not None:
- # Check for a dict[str, str]
- if origin is dict:
- if param.annotation.__args__[0] != (str, str):
- raise TypeError(f"Dicts must have str keys {error_name}")
- raise TypeError(f"Parameters cannot be generic, lists, sets, or dicts {error_name}")
+ if annotation in [int, float, str, bool]:
+ return BasicInput(param.name, context or Ctx())
- if param.annotation in [int, float, str, bool]:
- continue
+ if inspect.isclass(annotation) and issubclass(annotation, Model):
+ return ModelInput(param.name, context or Ctx())
- if issubclass(param.annotation, Model):
- continue
+ if t.get_origin(annotation) is list:
+ if not param.name:
+ raise TypeError(f"Nested list parameters are not supported: {error_name}")
- raise TypeError(
- f"Invalid parameter type: {param.annotation}, must be one of int, bool, str, float or rg.Model ({func.__name__}#{param.name})"
+ args = t.get_args(annotation)
+ if not args:
+ raise TypeError(f"List param must be fully typed: {error_name}")
+
+ arg_type, arg_context = unwrap_annotated(args[0])
+ return ListInput(
+ param.name, arg_context or context or Ctx(), parse_parameter(make_parameter(arg_type), error_name)
)
- if signature.return_annotation in [None, inspect.Parameter.empty]:
- raise TypeError(f"Return type annotation is required ({func.__name__})")
+ elif t.get_origin(annotation) is dict:
+ if not param.name:
+ raise TypeError(f"Nested dict parameters are not supported: {error_name}")
+
+ args = t.get_args(annotation)
+ if not args or len(args) != 2:
+ raise TypeError(f"Dict param must be fully typed: {error_name}")
+ if args[0] is not str:
+ raise TypeError(f"Dict param keys must be strings: {error_name}")
+
+ return DictInput(param.name, context or Ctx(), parse_parameter(make_parameter(args[1]), error_name))
+
+ raise TypeError(f"Unsupported parameter type: {error_name}")
+
+
+# Function Outputs
+
+
+@dataclasses.dataclass
+class Output:
+ id: str
+ context: Ctx
+
+ @property
+ def tag(self) -> str:
+ return self.context.tag or to_xml_tag(self.id)
+
+ def _prefix(self, xml: str) -> str:
+ if self.context.prefix:
+ return f"{self.context.prefix}\n{xml}"
+ return xml
+
+ def guidance(self) -> str:
+ return "Produce the following output:"
+
+ def to_format(self) -> str:
+ tag = self.context.tag or self.tag
+        return self._prefix(f"<{tag}>{escape_xml(self.context.example or '')}</{tag}>")
+
+ def from_chat(self, chat: Chat) -> t.Any:
+ raise NotImplementedError
+
+
+@dataclasses.dataclass
+class ChatOutput(Output):
+ def from_chat(self, chat: Chat) -> t.Any:
+ return chat
+
+
+@dataclasses.dataclass
+class BasicOutput(Output):
+ type_: type[t.Any] # TODO: We should be able to scope this down
+
+ def from_chat(self, chat: Chat) -> t.Any:
+ Temp = make_primitive("Model", self.type_, tag=self.context.tag or self.tag)
+ return chat.last.parse(Temp).content
+
+
+@dataclasses.dataclass
+class BasicListOutput(BasicOutput):
+ def guidance(self) -> str:
+ return "Produce the following output for each item:"
+
+ def from_chat(self, chat: Chat) -> t.Any:
+ Model = make_primitive("Model", self.type_, tag=self.context.tag or self.tag)
+ return [m.content for m in chat.last.parse_set(Model)]
+
+
+@dataclasses.dataclass
+class ModelOutput(Output):
+ type_: type[Model]
+
+ def to_format(self) -> str:
+ return self.type_.xml_example()
+
+ def from_chat(self, chat: Chat) -> t.Any:
+ return chat.last.parse(self.type_)
+
+
+@dataclasses.dataclass
+class ModelListOutput(ModelOutput):
+ def guidance(self) -> str:
+ return "Produce the following output for each item:"
+
+ def from_chat(self, chat: Chat) -> t.Any:
+ return chat.last.parse_set(self.type_)
+
+
+@dataclasses.dataclass
+class TupleOutput(Output):
+ interiors: list[Output]
+
+ @property
+ def real_interiors(self) -> list[Output]:
+ return [i for i in self.interiors if not isinstance(i, ChatOutput)]
+
+ @property
+ def wrapped(self) -> bool:
+ # Handles cases where we are using a tuple just to
+ # capture a Chat along with a real output, in this
+ # case we should fall through for most of the work
+ #
+ # () -> tuple[Chat, ...]
+ return len(self.real_interiors) == 1
+
+ def guidance(self) -> str:
+ if self.wrapped:
+ return self.real_interiors[0].guidance()
+ return "Produce the following outputs:"
+
+ def to_format(self) -> str:
+ if self.wrapped:
+ return self.real_interiors[0].to_format()
+ return self._prefix("\n\n".join(i.to_format() for i in self.real_interiors))
+
+ def from_chat(self, chat: Chat) -> t.Any:
+ return tuple(i.from_chat(chat) for i in self.interiors)
+
+
+@dataclasses.dataclass
+class DataclassOutput(TupleOutput):
+ type_: type[t.Any]
- if not isinstance(signature.return_annotation, tuple):
- return
+ def from_chat(self, chat: Chat) -> t.Any:
+ return self.type_(*super().from_chat(chat))
-def build_template(func: t.Callable) -> str:
- docstring = func.__doc__ or DEFAULT_DOC
- docstring = inspect.cleandoc(docstring)
+def parse_output(annotation: t.Any, error_name: str, *, allow_nested: bool = True) -> Output:
+ from rigging.chat import Chat
- signature = inspect.signature(func)
+ if annotation in [None, inspect.Parameter.empty]:
+ raise TypeError(f"Return type annotation is required ({error_name})")
+
+ # Unwrap any annotated types
+ annotation, context = unwrap_annotated(annotation)
+
+ if annotation == Chat:
+ # Use a special subclass here -> args don't matter
+ return ChatOutput(id="chat", context=context or Ctx())
+
+ if annotation in [int, float, str, bool]:
+ return BasicOutput(id=annotation.__name__, context=context or Ctx(), type_=annotation)
+
+ if t.get_origin(annotation) is list:
+ if not allow_nested:
+ raise TypeError(f"Nested list outputs are not supported ({error_name})")
+
+ args = t.get_args(annotation)
+ if not args:
+ raise TypeError(f"List return type must be fully specified ({error_name})")
+
+ arg_type, arg_context = unwrap_annotated(args[0])
+
+ if arg_type in [int, float, str, bool]:
+ return BasicListOutput(id=arg_type.__name__, context=arg_context or context or Ctx(), type_=arg_type)
+
+ if inspect.isclass(arg_type) and issubclass(arg_type, Model):
+ return ModelListOutput(id=arg_type.__name__, context=arg_context or context or Ctx(), type_=arg_type)
+
+ if t.get_origin(annotation) is tuple:
+ if not allow_nested:
+ raise TypeError(f"Nested tuple outputs are not supported ({error_name})")
+
+ args = t.get_args(annotation)
+ if not args:
+ raise TypeError(f"Tuple return type must be fully specified ({error_name})")
+
+ tuple_interiors = [parse_output(arg, error_name, allow_nested=False) for arg in args]
+
+ if len({i.tag for i in tuple_interiors}) != len(tuple_interiors):
+ raise TypeError(
+ f"Tuple return annotations must have unique internal types\n"
+                "or use Annotated[..., Ctx(tag=...)] overrides to\n"
+ f"make them differentiable ({error_name})"
+ )
+
+ return TupleOutput(id="tuple", context=context or Ctx(), interiors=tuple_interiors)
+
+ if dataclasses.is_dataclass(annotation):
+ interior_annotations: list[t.Any] = []
+ for field in dataclasses.fields(annotation):
+ field_annotation, field_context = unwrap_annotated(field.type)
+ if field_context is None:
+ field_annotation = t.Annotated[field_annotation, Ctx(tag=to_xml_tag(field.name))]
+ interior_annotations.append(field_annotation)
+
+ dataclass_interiors: list[Output] = []
+ for field, field_annotation in zip(dataclasses.fields(annotation), interior_annotations):
+ interior = parse_output(field_annotation, f"{error_name}#{field.name}", allow_nested=False)
+ if interior is None:
+                raise TypeError(f"Dataclass field type is invalid ({error_name}#{field.name})")
+ dataclass_interiors.append(interior)
+
+ if len({i.tag for i in dataclass_interiors}) != len(dataclass_interiors):
+ raise TypeError(
+ f"Dataclass return annotations must have unique internal types\n"
+                "or use Annotated[..., Ctx(tag=...)] overrides to\n"
+ f"make them differentiable ({error_name})"
+ )
+
+ return DataclassOutput(
+ id=annotation.__name__, type_=annotation, context=context or Ctx(), interiors=dataclass_interiors
+ )
+
+ # This has to come after our list/tuple checks as they pass isclass
+ if inspect.isclass(annotation) and issubclass(annotation, Model):
+ return ModelOutput(id=annotation.__name__, context=context or Ctx(), type_=annotation)
+
+ raise TypeError(f"Unsupported return type: {error_name}")
# Prompt
-class Prompt(BaseModel, t.Generic[P, R]):
- _func: t.Callable[P, t.Coroutine[None, None, R]]
+@dataclasses.dataclass
+class Prompt(t.Generic[P, R]):
+ """
+ Prompts wrap hollow functions and create structured chat interfaces for
+ passing inputs into a ChatPipeline and parsing outputs.
+ """
+
+ func: t.Callable[P, t.Coroutine[None, None, R]]
+ """The function that the prompt wraps. This function should be a coroutine."""
+
+ attempt_recovery: bool = True
+ """Whether the prompt should attempt to recover from errors in output parsing."""
+ drop_dialog: bool = True
+ """When attempting recovery, whether to drop intermediate dialog while parsing was being resolved."""
+ max_rounds: int = DEFAULT_MAX_ROUNDS
+ """The maximum number of rounds the prompt should try to reparse outputs."""
+
+ inputs: list[Input] = dataclasses.field(default_factory=list)
+ """The structured input handlers for the prompt."""
+ output: Output = ChatOutput(id="chat", context=Ctx())
+ """The structured output handler for the prompt."""
+
+ watch_callbacks: list[WatchChatCallback] = dataclasses.field(default_factory=list)
+ """Callbacks to be passed any chats produced while executing this prompt."""
+ params: GenerateParams | None = None
+ """The parameters to be used when generating chats for this prompt."""
_generator_id: str | None = None
_generator: Generator | None = None
_pipeline: ChatPipeline | None = None
- _params: GenerateParams | None = None
- @model_validator(mode="after")
- def check_valid_function(self) -> Prompt[P, R]:
- check_valid_function(self._func)
- return self
+ def __post_init__(self) -> None:
+ signature = inspect.signature(self.func)
+ undeclared = get_undeclared_variables(self.docstring)
+
+ for param in signature.parameters.values():
+ if param.name in undeclared:
+ continue
+ error_name = f"{self.func.__name__}({param})"
+ self.inputs.append(parse_parameter(param, error_name))
+
+ if len({i.tag for i in self.inputs}) != len(self.inputs):
+ raise TypeError("All input parameters must have unique names/tags")
+
+ error_name = f"{self.func.__name__}() -> {signature.return_annotation}"
+ self.output = parse_output(signature.return_annotation, error_name)
+
+ @property
+ def docstring(self) -> str:
+ """The docstring for the prompt function."""
+ # Guidance is taken from https://github.com/outlines-dev/outlines/blob/main/outlines/prompts.py
+ docstring = self.func.__doc__ or DEFAULT_DOC.format(func_name=self.func.__name__)
+ docstring = inspect.cleandoc(docstring)
+ docstring = re.sub(r"(?![\r\n])(\b\s+)", " ", docstring)
+ return docstring
- @computed_field # type: ignore [misc]
@property
def template(self) -> str:
- return ""
+ """The dynamic jinja2 template for the prompt function."""
+ text = f"{self.docstring}\n"
+
+ for input_ in self.inputs:
+ text += "\n{{ " + to_snake(input_.tag) + " }}\n"
+
+ if self.output is None or isinstance(self.output, ChatOutput):
+ return text
+
+ text += f"\n{self.output.guidance()}\n"
+ text += f"\n{self.output.to_format()}\n"
+
+ return text
@property
def pipeline(self) -> ChatPipeline | None:
+ """If available, the resolved Chat Pipeline for the prompt."""
if self._pipeline is not None:
- return self._pipeline.with_(params=self._params)
+ return self._pipeline
+
+ if self._generator is None and self._generator_id is not None:
+ self._generator = get_generator(self._generator_id)
- if self._generator is None:
- if self._generator_id is None:
- raise ValueError(
- "You cannot execute this prompt ad-hoc. No pipeline, generator, or generator_id was provided."
+ if self._generator is not None:
+ self._pipeline = self._generator.chat()
+ return self._pipeline
+
+ return None
+
+ def _until_parsed(self, message: Message) -> tuple[bool, list[Message]]:
+ should_continue: bool = False
+ generated: list[Message] = [message]
+
+ if self.output is None or isinstance(self.output, ChatOutput):
+ return (should_continue, generated)
+
+ try:
+ # A bit weird, but we need from_chat to properly handle
+ # wrapping Chat output types inside lists/dataclasses
+ self.output.from_chat(Chat([], generated=[message]))
+ except ValidationError as e:
+ should_continue = True
+ generated.append(
+ Message.from_model(
+ ValidationErrorModel(content=str(e)),
+ suffix="Rewrite your entire message with all the required elements.",
)
+ )
+ except Exception as e:
+ should_continue = True
+ generated.append(
+ Message.from_model(
+ SystemErrorModel(content=str(e)),
+ suffix="Rewrite your entire message with all the required elements.",
+ )
+ )
+
+ return (should_continue, generated)
+
+ def clone(self, *, skip_callbacks: bool = False) -> Prompt[P, R]:
+ """
+ Creates a deep copy of this prompt.
+
+ Args:
+ skip_callbacks: Whether to skip copying the watch callbacks.
+
+ Returns:
+ A new instance of the prompt.
+ """
+ new = Prompt(
+ func=self.func,
+ _pipeline=self.pipeline,
+ params=self.params.model_copy() if self.params is not None else None,
+ attempt_recovery=self.attempt_recovery,
+ drop_dialog=self.drop_dialog,
+ max_rounds=self.max_rounds,
+ )
+ if not skip_callbacks:
+ new.watch_callbacks = self.watch_callbacks.copy()
+ return new
- self._generator = get_generator(self._generator_id)
+ def with_(self, params: t.Optional[GenerateParams] = None, **kwargs: t.Any) -> Prompt[P, R]:
+ """
+ Assign specific generation parameter overloads for this prompt.
- return self._generator.chat(params=self._params)
+ Note:
+ This will trigger a `clone` if overload params have already been set.
- def clone(self) -> Prompt[P, R]:
- return Prompt(_func=self._func, _pipeline=self.pipeline)
+ Args:
+ params: The parameters to set for the underlying chat pipeline.
+ **kwargs: An alternative way to pass parameters as keyword arguments.
- def with_(self, params: t.Optional[GenerateParams] = None, **kwargs: t.Any) -> Prompt[P, R]:
+ Returns:
+ Prompt with the updated params.
+ """
if params is None:
params = GenerateParams(**kwargs)
- if self._params is not None:
+ if self.params is not None:
new = self.clone()
- new._params = self._params.merge_with(params)
+ new.params = self.params.merge_with(params)
return new
self.params = params
return self
+ # We could put these params into the decorator, but it makes it
+ # less flexible when we want to build gateway interfaces into
+ # creating a prompt from other code.
+
+ def set_(
+ self, attempt_recovery: bool | None = None, drop_dialog: bool | None = None, max_rounds: int | None = None
+ ) -> Prompt[P, R]:
+ """
+ Helper to allow updates to the parsing configuration.
+
+ Args:
+ attempt_recovery: Whether the prompt should attempt to recover from errors in output parsing.
+ drop_dialog: When attempting recovery, whether to drop intermediate dialog while parsing was being resolved.
+ max_rounds: The maximum number of rounds the prompt should try to reparse outputs.
+
+ Returns:
+            The current instance of the prompt.
+ """
+        self.attempt_recovery = attempt_recovery if attempt_recovery is not None else self.attempt_recovery
+        self.drop_dialog = drop_dialog if drop_dialog is not None else self.drop_dialog
+        self.max_rounds = max_rounds if max_rounds is not None else self.max_rounds
+ return self
+
+ def watch(self, *callbacks: WatchChatCallback) -> Prompt[P, R]:
+ """
+        Registers a callback to monitor any chats produced for this prompt.
+
+ Args:
+ *callbacks: The callback functions to be executed.
+
+ ```
+ async def log(chats: list[Chat]) -> None:
+ ...
+
+        @rg.prompt(generator_id="gpt-3.5-turbo")
+        async def summarize(text: str) -> str:
+            ...
+
+        summarize.watch(log)
+
+        await summarize(...)
+ ```
+
+ Returns:
+            The current instance of the prompt.
+ """
+ for callback in callbacks:
+ if callback not in self.watch_callbacks:
+ self.watch_callbacks.append(callback)
+ return self
+
def render(self, *args: P.args, **kwargs: P.kwargs) -> str:
- pass
+ """
+ Pass the arguments to the jinja2 template and render the full prompt.
+ """
+ env = Environment(
+ trim_blocks=True,
+ lstrip_blocks=True,
+ keep_trailing_newline=True,
+ undefined=StrictUndefined,
+ )
+ jinja_template = env.from_string(self.template)
+
+ signature = inspect.signature(self.func)
+ bound_args = signature.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+
+ for input_ in self.inputs:
+ bound_args.arguments[to_snake(input_.tag)] = input_.to_xml(bound_args.arguments[input_.name])
+
+ return jinja_template.render(**bound_args.arguments)
+
+ def process(self, chat: Chat) -> R:
+ """
+ Attempt to parse the output from a chat into the expected return type.
+ """
+ return self.output.from_chat(chat) # type: ignore
+
+ async def run(self, *args: P.args, pipeline: ChatPipeline | None = None, **kwargs: P.kwargs) -> R:
+ """
+ Use the prompt to run the function with the provided arguments and return the output.
+
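+        Example (hypothetical prompt shown for illustration):
+
+        ```
+        @rg.prompt(generator_id="gpt-3.5-turbo")
+        async def summarize(text: str) -> str:
+            \"""Summarize the text.\"""
+            ...
+
+        summary = await summarize.run("...")  # equivalent to `await summarize("...")`
+        ```
+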
+ Args:
+ *args: The positional arguments for the prompt function.
+ pipeline: An optional pipeline to use for the prompt.
+ **kwargs: The keyword arguments for the prompt function.
+
+ Returns:
+ The output of the prompt function.
+ """
+ return (await self.run_many(1, *args, pipeline=pipeline, **kwargs))[0]
+
+ async def run_many(
+ self, count: int, *args: P.args, pipeline: ChatPipeline | None = None, **kwargs: P.kwargs
+ ) -> list[R]:
+ """
+ Use the prompt to run the function multiple times with the provided arguments and return the output.
+
+ Args:
+ count: The number of times to run the prompt.
+ *args: The positional arguments for the prompt function.
+ pipeline: An optional pipeline to use for the prompt.
+ **kwargs: The keyword arguments for the prompt function.
+
+ Returns:
+ The outputs of the prompt function.
+ """
+ pipeline = pipeline or self.pipeline
+ if pipeline is None:
+ raise RuntimeError(
+ "Prompt cannot be executed as a standalone function without being assigned a pipeline or generator"
+ )
+
+ content = self.render(*args, **kwargs)
+ pipeline = (
+ pipeline.fork(content)
+ .until(
+ self._until_parsed,
+ attempt_recovery=self.attempt_recovery,
+ drop_dialog=self.drop_dialog,
+ max_rounds=self.max_rounds,
+ )
+ .with_(self.params)
+ )
+ chats = await pipeline.run_many(count)
- async def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
- pass
+ coros = [watch(chats) for watch in set(self.watch_callbacks + pipeline.watch_callbacks)]
+ await asyncio.gather(*coros)
- async def run_many(self, *args: P.args, **kwargs: P.kwargs) -> list[R]:
- pass
+ return [self.process(chat) for chat in chats]
__call__ = run
+# Decorator
+
+
+@t.overload
def prompt(
- *, pipeline: ChatPipeline | None = None, generator: Generator | None = None, generator_id: str | None = None
+ func: None = None,
+ /,
+ *,
+ pipeline: ChatPipeline | None = None,
+ generator: Generator | None = None,
+ generator_id: str | None = None,
) -> t.Callable[[t.Callable[P, t.Coroutine[None, None, R]]], Prompt[P, R]]:
- if sum(arg is not None for arg in (pipeline, generator, generator_id)) > 1:
- raise ValueError("Only one of pipeline, generator, or generator_id can be provided")
+ ...
- def decorator(func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]:
- return Prompt[P, R](_func=func, _generator_id=generator_id, _pipeline=pipeline, _generator=generator)
- return decorator
+@t.overload
+def prompt(
+ func: t.Callable[P, t.Coroutine[None, None, R]],
+ /,
+ *,
+ pipeline: ChatPipeline | None = None,
+ generator: Generator | None = None,
+ generator_id: str | None = None,
+) -> Prompt[P, R]:
+ ...
+
+def prompt(
+ func: t.Callable[P, t.Coroutine[None, None, R]] | None = None,
+ /,
+ *,
+ pipeline: ChatPipeline | None = None,
+ generator: Generator | None = None,
+ generator_id: str | None = None,
+) -> t.Callable[[t.Callable[P, t.Coroutine[None, None, R]]], Prompt[P, R]] | Prompt[P, R]:
+ """
+ Convert a hollow function into a Prompt, which can be called directly or passed a
+ chat pipeline to execute the function and parse the outputs.
+
+ ```
+ from dataclasses import dataclass
+ import rigging as rg
+
+ @dataclass
+ class ExplainedJoke:
+ chat: rg.Chat
+ setup: str
+ punchline: str
+ explanation: str
+
+ @rg.prompt(generator_id="gpt-3.5-turbo")
+ async def write_joke(topic: str) -> ExplainedJoke:
+ \"""Write a joke.\"""
+ ...
+
+ await write_joke("programming")
+    ```
+
+ Note:
+ A docstring is not required, but this can be used to provide guidance to the model, or
+    even handle any number of input transformations. Any input parameter which is not
+ handled inside the docstring will be automatically added and formatted internally.
+
+ Note:
+ Output parameters can be basic types, dataclasses, rigging models, lists, or tuples.
+ Internal inspection will attempt to ensure your output types are valid, but there is
+ no guarantee of complete coverage/safety. It's recommended to check
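+        """Shorthand for calling `prompt.run(...)` with this pipeline. See [rigging.prompt.Prompt.run][]."""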
+ [rigging.prompt.Prompt.template][] to inspect the generated jinja2 template.
+
+ Note:
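+        """Shorthand for calling `prompt.run_many(...)` with this pipeline. See [rigging.prompt.Prompt.run_many][]."""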
+ If you annotate the return value of the function as a [rigging.chat.Chat][] object,
+ then no output parsing will take place and you can parse objects out manually.
+
+        You can also use Chat in any number of type annotations inside tuples or dataclasses.
+ All instances will be filled with the final chat object transparently.
+
+ Note:
+ All input parameters and output types can be annotated with the [rigging.prompt.Ctx][] annotation
+ to provide additional context for the prompt. This can be used to override the xml tag, provide
+ a prefix string, or example content which will be placed inside output xml tags.
+
+ In the case of output parameters, especially in tuples, you might have xml tag collisions
+ between the same basic types. Manually annotating xml tags with [rigging.prompt.Ctx][] is
+ recommended.
+
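+    For example (illustrative tags):
+
+    ```
+    from typing import Annotated
+
+    @rg.prompt
+    async def summarize(text: str) -> tuple[
+        Annotated[str, rg.Ctx(tag="short-summary")],
+        Annotated[str, rg.Ctx(tag="long-summary")],
+    ]:
+        \"""Summarize the text twice.\"""
+        ...
+    ```
+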
+ Args:
+ func: The function to convert into a prompt.
+ pipeline: An optional pipeline to use for the prompt.
+ generator: An optional generator to use for the prompt.
+ generator_id: An optional generator id to use for the prompt.
+
+ Returns:
+ A prompt instance or a function that can be used to create a prompt.
+ """
+ if sum(arg is not None for arg in (pipeline, generator, generator_id)) > 1:
+ raise ValueError("Only one of pipeline, generator, or generator_id can be provided")
+
+ def make_prompt(func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]:
+ return Prompt[P, R](
+ func=func,
+ _generator_id=generator_id,
+ _pipeline=pipeline,
+ _generator=generator,
+ )
-@prompt()
-async def testing() -> None:
- pass
+ if func is not None:
+ return make_prompt(func)
+ return make_prompt
diff --git a/rigging/util.py b/rigging/util.py
new file mode 100644
index 0000000..9137788
--- /dev/null
+++ b/rigging/util.py
@@ -0,0 +1,32 @@
+import re
+import typing as t
+
+from pydantic import alias_generators
+from typing_extensions import ParamSpec
+
+P = ParamSpec("P")
+R = t.TypeVar("R")
+
+
+def escape_xml(xml_string: str) -> str:
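+    # Escape bare '&' characters while leaving existing entities (&amp;, &lt;, etc.) untouched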
+    prepared = re.sub(r"&(?!(?:amp|lt|gt|apos|quot);)", "&amp;", xml_string)
+
+ return prepared
+
+
+def unescape_xml(xml_string: str) -> str:
+    unescaped = re.sub(r"&amp;", "&", xml_string)
+    unescaped = re.sub(r"&lt;", "<", unescaped)
+    unescaped = re.sub(r"&gt;", ">", unescaped)
+    unescaped = re.sub(r"&apos;", "'", unescaped)
+    unescaped = re.sub(r"&quot;", '"', unescaped)
+
+ return unescaped
+
+
+def to_snake(text: str) -> str:
+ return alias_generators.to_snake(text)
+
+
+def to_xml_tag(text: str) -> str:
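+    # e.g. "ExplainedJoke" -> "explained-joke", "input_" -> "input"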
+ return to_snake(text).replace("_", "-").strip("-")
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
new file mode 100644
index 0000000..cc12e64
--- /dev/null
+++ b/tests/test_prompt.py
@@ -0,0 +1,239 @@
+from dataclasses import dataclass
+from textwrap import dedent
+from typing import Annotated
+
+import pytest
+
+import rigging as rg
+from rigging.chat import Chat
+
+# mypy: disable-error-code=empty-body
+
+
+def test_prompt_render_docstring_parse() -> None:
+ @rg.prompt
+ async def foo(name: str) -> str:
+ """Say hello."""
+ ...
+
+ assert foo.docstring == "Say hello."
+
+ @rg.prompt
+ async def bar(name: str) -> str:
+ """
+ Say hello."""
+ ...
+
+ assert bar.docstring == "Say hello."
+
+ @rg.prompt
+ async def baz(name: str) -> str:
+ """
+ Say \
+ hello.
+
+ """
+ ...
+
+ assert baz.docstring == "Say hello."
+
+
+def test_basic_prompt_render() -> None:
+ @rg.prompt
+ async def hello(name: str) -> str:
+ """Say hello."""
+ ...
+
+ rendered = hello.render("Alice")
+ assert rendered == dedent(
+ """\
+ Say hello.
+
+        <name>Alice</name>
+
+ Produce the following output:
+
+        <str></str>
+ """
+ )
+
+
+def test_prompt_render_with_docstring_variables() -> None:
+ @rg.prompt
+ async def greet(name: str, greeting: str = "Hello") -> str:
+ """Say '{{ greeting }}' to {{ name }}."""
+ ...
+
+ rendered = greet.render("Bob")
+ assert rendered == dedent(
+ """\
+ Say 'Hello' to Bob.
+
+ Produce the following output:
+
+        <str></str>
+ """
+ )
+
+
+def test_prompt_render_with_model_output() -> None:
+ class Person(rg.Model):
+ name: str = rg.element()
+ age: int = rg.element()
+
+ @rg.prompt
+ async def create_person(name: str, age: int) -> Person:
+ """Create a person."""
+ ...
+
+ rendered = create_person.render("Alice", 30)
+ assert rendered == dedent(
+ """\
+ Create a person.
+
+        <name>Alice</name>
+
+        <age>30</age>
+
+ Produce the following output:
+
+        <person>
+          <name/>
+          <age/>
+        </person>
+ """
+ )
+
+
+def test_prompt_render_with_list_output() -> None:
+ @rg.prompt
+ async def generate_numbers(count: int) -> list[int]:
+ """Generate a list of numbers."""
+ ...
+
+ rendered = generate_numbers.render(5)
+ assert rendered == dedent(
+ """\
+ Generate a list of numbers.
+
+        <count>5</count>
+
+ Produce the following output for each item:
+
+        <int></int>
+ """
+ )
+
+
+def test_prompt_render_with_tuple_output() -> None:
+ @rg.prompt
+ async def create_user(username: str) -> tuple[str, int]:
+ """Create a new user."""
+ ...
+
+ rendered = create_user.render("johndoe")
+ assert rendered == dedent(
+ """\
+ Create a new user.
+
+        <username>johndoe</username>
+
+ Produce the following outputs:
+
+        <str></str>
+
+        <int></int>
+ """
+ )
+
+
+def test_prompt_render_with_tuple_output_ctx() -> None:
+ @rg.prompt
+ async def create_user(username: str) -> tuple[Annotated[str, rg.Ctx(tag="id")], int]:
+ """Create a new user."""
+ ...
+
+ rendered = create_user.render("johndoe")
+ assert rendered == dedent(
+ """\
+ Create a new user.
+
+        <username>johndoe</username>
+
+ Produce the following outputs:
+
+        <id></id>
+
+        <int></int>
+ """
+ )
+
+
+def test_prompt_render_with_dataclass_output() -> None:
+ @dataclass
+ class User:
+ username: str
+ email: str
+ age: int
+
+ @rg.prompt
+ async def register_user(username: str, email: str, age: int) -> User:
+ """Register a new user: {{ username}}."""
+ ...
+
+ rendered = register_user.render("johndoe", "johndoe@example.com", 25)
+ assert rendered == dedent(
+ """\
+ Register a new user: johndoe.
+
+        <email>johndoe@example.com</email>
+
+        <age>25</age>
+
+ Produce the following outputs:
+
+        <username></username>
+
+        <email></email>
+
+        <age></age>
+ """
+ )
+
+
+def test_prompt_render_with_chat_return() -> None:
+ @rg.prompt
+ async def foo(input_: str) -> Chat:
+ """Do something."""
+ ...
+
+ rendered = foo.render("bar")
+ assert rendered == dedent(
+ """\
+ Do something.
+
+        <input>bar</input>
+ """
+ )
+
+
+def test_prompt_parse_fail_nested_input() -> None:
+ async def foo(arg: list[list[str]]) -> Chat:
+ ...
+
+ with pytest.raises(TypeError):
+ rg.prompt(foo)
+
+ async def bar(arg: tuple[int, str, tuple[str]]) -> Chat:
+ ...
+
+ with pytest.raises(TypeError):
+ rg.prompt(bar)
+
+
+def test_prompt_parse_fail_unique_output() -> None:
+ async def foo(arg: int) -> tuple[str, str]:
+ ...
+
+ with pytest.raises(TypeError):
+ rg.prompt(foo)