Skip to content

Commit

Permalink
Finalizing new prompt interface
Browse files Browse the repository at this point in the history
  • Loading branch information
monoxgas committed Jun 6, 2024
1 parent a332421 commit b55b782
Show file tree
Hide file tree
Showing 14 changed files with 1,175 additions and 173 deletions.
1 change: 1 addition & 0 deletions docs/api/prompt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: rigging.prompt
Empty file added docs/topics/prompt-functions.md
Empty file.
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ nav:
- Models: topics/models.md
- Generators: topics/generators.md
- Chats and Messages: topics/chats-and-messages.md
- Prompt Functions: topics/prompt-functions.md
- Completions: topics/completions.md
- Callbacks and Mapping: topics/callbacks-and-mapping.md
- Async and Batching: topics/async-and-batching.md
Expand All @@ -26,6 +27,7 @@ nav:
- rigging.generator: api/generator.md
- rigging.model: api/model.md
- rigging.message: api/message.md
- rigging.prompt: api/prompt.md
- rigging.tool: api/tool.md
- rigging.data: api/data.md
- rigging.parsing: api/parsing.md
Expand Down
91 changes: 51 additions & 40 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ asyncssh = { version = "^2.14.2", optional = true }
types-requests = { version = "^2.32.0.20240523", optional = true }
click = { version = "^8.1.7", optional = true }
httpx = { version = "^0.27.0", optional = true }
xmltodict = "^0.13.0"

[tool.poetry.extras]
examples = ["asyncssh", "types-requests", "click", "httpx"]
Expand Down
4 changes: 4 additions & 0 deletions rigging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from rigging.message import Message, MessageDict, Messages
from rigging.model import Model, attr, element, make_primitive, wrapped
from rigging.prompt import Ctx, Prompt, prompt
from rigging.tool import Tool

__version__ = "1.3.0"
Expand Down Expand Up @@ -56,6 +57,9 @@
"chats_to_elastic_data",
"flatten_chats",
"unflatten_chats",
"prompt",
"Prompt",
"Ctx",
]

from loguru import logger
Expand Down
100 changes: 95 additions & 5 deletions rigging/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from elasticsearch import AsyncElasticsearch

from rigging.data import ElasticOpType
from rigging.prompt import Prompt
from rigging.util import P, R

DEFAULT_MAX_ROUNDS = 5
"""Maximum number of internal callback rounds to attempt during generation before giving up."""
Expand Down Expand Up @@ -203,14 +205,28 @@ def continue_(self, messages: t.Sequence[Message] | t.Sequence[MessageDict] | Me
return self.fork(messages, include_all=True)

def clone(self, *, only_messages: bool = False) -> Chat:
    """
    Creates a deep copy of the chat.

    Args:
        only_messages: If True, only the messages will be cloned.
            If False (default), the entire chat object will be cloned.

    Returns:
        A new instance of Chat.
    """
    copied = Chat(
        [message.model_copy() for message in self.messages],
        [message.model_copy() for message in self.generated],
        self.generator,
    )
    if only_messages:
        return copied

    # Carry over the full chat state, deep-copying mutable containers
    # so the clone cannot mutate the original.
    copied.metadata = deepcopy(self.metadata)
    copied.params = None if self.params is None else self.params.model_copy()
    copied.stop_reason = self.stop_reason
    copied.usage = None if self.usage is None else self.usage.model_copy()
    copied.extra = deepcopy(self.extra)
    copied.failed = self.failed
    return copied

def apply(self, **kwargs: str) -> Chat:
Expand Down Expand Up @@ -371,7 +387,7 @@ async def to_elastic(
"""
from rigging.data import chats_to_elastic

return chats_to_elastic(self, index, client, op_type=op_type, create_index=create_index, **kwargs)
return await chats_to_elastic(self, index, client, op_type=op_type, create_index=create_index, **kwargs)


# Callbacks
Expand Down Expand Up @@ -564,7 +580,12 @@ def clone(self, *, only_messages: bool = False) -> ChatPipeline:
Returns:
A new instance of `ChatPipeline` that is a clone of the current instance.
"""
new = ChatPipeline(self.generator, [], params=self.params, watch_callbacks=self.watch_callbacks)
new = ChatPipeline(
self.generator,
[],
params=self.params.model_copy() if self.params is not None else None,
watch_callbacks=self.watch_callbacks,
)
new.chat = self.chat.clone()
if not only_messages:
new.until_callbacks = self.until_callbacks.copy()
Expand All @@ -573,8 +594,8 @@ def clone(self, *, only_messages: bool = False) -> ChatPipeline:
new.inject_tool_prompt = self.inject_tool_prompt
new.force_tool = self.force_tool
new.metadata = deepcopy(self.metadata)
new.then_chat_callbacks = self.then_chat_callbacks.copy()
new.map_chat_callbacks = self.map_chat_callbacks.copy()
new.then_callbacks = self.then_callbacks.copy()
new.map_callbacks = self.map_callbacks.copy()
return new

def meta(self, **kwargs: t.Any) -> ChatPipeline:
Expand Down Expand Up @@ -1157,3 +1178,72 @@ async def run_batch(
chats = [s.chat for s in states if s.chat is not None]

return await self._post_run(chats)

# Generator iteration

async def arun_over(
    self,
    *generators: Generator | str,
    include_original: bool = True,
    skip_failed: bool = False,
    include_failed: bool = False,
) -> ChatList:
    """
    Executes the generation process across multiple generators.

    For each generator, this pending chat is cloned and the generator is replaced
    before the run call. All callbacks and parameters are preserved.

    Args:
        *generators: A sequence of generators (or generator identifier strings)
            to be used for the generation process.
        include_original: Whether to include the original generator in the list of runs.
        skip_failed: Enable to ignore any max rounds errors and return only successful chats.
        include_failed: Enable to ignore max rounds errors and return both
            successful and failed chats.

    Returns:
        A list of generated Chats.

    Raises:
        ValueError: If both `skip_failed` and `include_failed` are set.
    """
    if skip_failed and include_failed:
        raise ValueError("Cannot use both skip_failed and include_failed")

    # Resolve any string identifiers into concrete Generator instances.
    _generators: list[Generator] = [g if isinstance(g, Generator) else get_generator(g) for g in generators]
    if include_original:
        _generators.append(self.generator)

    # One cloned pipeline per generator, all gathered concurrently.
    coros: list[t.Coroutine[t.Any, t.Any, Chat]] = []
    for generator in _generators:
        sub = self.clone()
        sub.generator = generator
        coros.append(sub.run(allow_failed=skip_failed or include_failed))

    chats = await asyncio.gather(*coros)

    if skip_failed:
        # Drop chats that failed; with include_failed they are kept instead.
        chats = [c for c in chats if not c.failed]

    return ChatList(chats)

# Prompt functions

def prompt(self, func: t.Callable[P, t.Coroutine[None, None, R]]) -> Prompt[P, R]:
    """
    Decorator to convert a function into a prompt bound to this pipeline.

    See [rigging.prompt.prompt][] for more information.

    Args:
        func: The function to be converted into a prompt.

    Returns:
        The prompt.
    """
    # Alias the import so it doesn't shadow this method's own name.
    from rigging.prompt import prompt as make_prompt

    return make_prompt(func, pipeline=self)

async def run_prompt(self, prompt: Prompt[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    """
    Executes a single run of a prompt using this pipeline.

    Args:
        prompt: The prompt to run.
        *args: Positional arguments forwarded to the prompt.
        **kwargs: Keyword arguments forwarded to the prompt.

    Returns:
        The result of the prompt run.
    """
    return await prompt.run(*args, pipeline=self, **kwargs)

async def run_prompt_many(self, prompt: Prompt[P, R], count: int, *args: P.args, **kwargs: P.kwargs) -> list[R]:
    """
    Executes multiple runs of a prompt using this pipeline.

    Args:
        prompt: The prompt to run.
        count: The number of times to run the prompt.
        *args: Positional arguments forwarded to the prompt.
        **kwargs: Keyword arguments forwarded to the prompt.

    Returns:
        A list with one result per run.
    """
    return await prompt.run_many(count, *args, pipeline=self, **kwargs)
Loading

0 comments on commit b55b782

Please sign in to comment.