Skip to content

Commit

Permalink
Interface improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
monoxgas committed Apr 24, 2024
1 parent 309ef4b commit ff8641d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 22 deletions.
79 changes: 58 additions & 21 deletions rigging/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,20 @@ def restart(self) -> "PendingChat":
raise ValueError("Cannot restart chat that was not created with a PendingChat")
return PendingChat(self.pending_chat.generator, self.messages, self.pending_chat.params)

@t.overload
def continue_(self, messages: t.Sequence[MessageDict]) -> "PendingChat":
...

@t.overload
def continue_(self, messages: MessageDict) -> "PendingChat":
...
# TODO: Why are these overloads here? I wonder if IDEs preferred them

@t.overload
def continue_(self, messages: t.Sequence[Message]) -> "PendingChat":
...

@t.overload
def continue_(self, messages: Message) -> "PendingChat":
...

def continue_(
self, messages: t.Sequence[Message] | t.Sequence[MessageDict] | Message | MessageDict
def fork(
self, messages: t.Sequence[Message] | t.Sequence[MessageDict] | Message | MessageDict | str
) -> "PendingChat":
if self.pending_chat is None:
raise ValueError("Cannot continue chat that was not created with a PendingChat")

messages_list: list[Message] = (
Message.fit_list(messages) if isinstance(messages, t.Sequence) else [Message.fit(messages)]
)
return PendingChat(self.pending_chat.generator, self.all + messages_list, self.pending_chat.params)
pending = PendingChat(self.pending_chat.generator, self.all, self.pending_chat.params)
pending.add(messages)
return pending

def continue_(self, messages: t.Sequence[Message] | t.Sequence[MessageDict] | Message | str) -> "PendingChat":
return self.fork(messages)

def clone(self) -> "Chat":
return Chat(
Expand Down Expand Up @@ -151,6 +139,32 @@ def with_params(self, params: "GenerateParams") -> "PendingChat":
self.params = params
return self

def add(
self, messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str
) -> "PendingChat":
message_list: list[Message] = (
[Message.fit(messages)]
if not isinstance(messages, t.Sequence) or isinstance(messages, str)
else Message.fit_list(messages)
)
# If the last message is the same role as the first new message, append to it
if self.chat.next_messages and self.chat.next_messages[-1].role == message_list[0].role:
self.chat.next_messages[-1].content += "\n" + message_list[0].content
message_list = message_list[1:]
else:
self.chat.next_messages += message_list
return self

def fork(
self, messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str
) -> "PendingChat":
return self.clone().add(messages)

def continue_(
self, messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str
) -> "PendingChat":
return self.fork(messages)

def clone(self) -> "PendingChat":
new = PendingChat(self.generator, [], self.params)
new.chat = self.chat.clone()
Expand Down Expand Up @@ -334,6 +348,29 @@ def _execute(self) -> list[Message]:

return new_messages

@t.overload
def run_with(
self,
messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str,
count: t.Literal[None] = None,
) -> Chat:
...

@t.overload
def run_with(
self,
messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str,
count: int,
) -> list[Chat]:
...

def run_with(
self,
messages: t.Sequence[MessageDict] | t.Sequence[Message] | MessageDict | Message | str,
count: int | None = None,
) -> Chat | list[Chat]:
return self.add(messages).run(count)

@t.overload
def run(self, count: t.Literal[None] = None) -> Chat:
...
Expand Down
4 changes: 3 additions & 1 deletion rigging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def fit_list(cls, messages: t.Sequence["Message"] | t.Sequence[MessageDict]) ->
return [cls.fit(message) for message in messages]

@classmethod
def fit(cls, message: t.Union["Message", MessageDict]) -> "Message":
def fit(cls, message: t.Union["Message", MessageDict, str]) -> "Message":
if isinstance(message, str):
return cls(role="user", content=message)
return cls(**message) if isinstance(message, dict) else message


Expand Down
18 changes: 18 additions & 0 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,24 @@ def test_chat_continue() -> None:
assert continued.all[2].content == "How are you?"


def test_pending_chat_continue() -> None:
pending = PendingChat(get_generator("gpt-3.5"), [], GenerateParams())
continued = pending.continue_([Message("user", "Hello")])

assert continued != pending
assert len(continued.chat) == 1
assert continued.chat.all[0].content == "Hello"


def test_pending_chat_add() -> None:
pending = PendingChat(get_generator("gpt-3.5"), [Message("user", "Hello")], GenerateParams())
added = pending.add(Message("user", "Hello"))

assert added == pending
assert len(added.chat) == 2
assert added.chat.all[0].content == "Hello"


def test_chat_continue_maintains_parsed_models() -> None:
chat = Chat(
[
Expand Down

0 comments on commit ff8641d

Please sign in to comment.