Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-494] raise error when prompt format is called with the wrong arguments #768

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions sdks/python/src/opik/api_objects/prompt/prompt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

from opik.rest_api import PromptVersionDetail
from . import prompt_template


class Prompt:
Expand Down Expand Up @@ -31,7 +32,8 @@ def __init__(
prompt=prompt,
)
self._name = new_instance.name
self._prompt = new_instance.prompt
self._prompt = prompt_template.PromptTemplate(prompt)

self._commit = new_instance.commit
self.__internal_api__version_id__: str = (
new_instance.__internal_api__version_id__
Expand All @@ -46,7 +48,7 @@ def name(self) -> str:
@property
def prompt(self) -> str:
"""The latest template of the prompt."""
return self._prompt
return str(self._prompt)

@property
def commit(self) -> str:
Expand All @@ -64,10 +66,7 @@ def format(self, **kwargs: Any) -> str:
Returns:
A string with all placeholders replaced by their corresponding values from kwargs.
"""
template = self._prompt
for key, value in kwargs.items():
template = template.replace(f"{{{{{key}}}}}", str(value))
return template
return self._prompt.format(**kwargs)

@classmethod
def from_fern_prompt_version(
Expand All @@ -81,7 +80,7 @@ def from_fern_prompt_version(
prompt.__internal_api__version_id__ = prompt_version.id
prompt.__internal_api__prompt_id__ = prompt_version.prompt_id
prompt._name = name
prompt._prompt = prompt_version.template
prompt._prompt = prompt_template.PromptTemplate(prompt_version.template)
prompt._commit = prompt_version.commit

return prompt
34 changes: 34 additions & 0 deletions sdks/python/src/opik/api_objects/prompt/prompt_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Set, Any
from opik import exceptions

import re


class PromptTemplate:
def __init__(self, template: str) -> None:
self._template = template
self._placeholders = _extract_placeholder_keys(template)

def format(self, **kwargs: Any) -> str:
template = self._template
placeholders = self._placeholders

kwargs_keys: Set[str] = set(kwargs.keys())

if kwargs_keys != placeholders:
raise exceptions.PromptPlaceholdersDontMatchFormatArguments(
prompt_placeholders=placeholders, format_arguments=kwargs_keys
)

for key, value in kwargs.items():
template = template.replace(f"{{{{{key}}}}}", str(value))

return template

def __str__(self) -> str:
return self._template


def _extract_placeholder_keys(prompt_template: str) -> Set[str]:
pattern = r"\{\{(.*?)\}\}"
return set(re.findall(pattern, prompt_template))
20 changes: 20 additions & 0 deletions sdks/python/src/opik/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Set


class OpikException(Exception):
pass

Expand All @@ -16,3 +19,20 @@ class ConfigurationError(OpikException):

class ScoreMethodMissingArguments(OpikException):
pass


class PromptPlaceholdersDontMatchFormatArguments(OpikException):
def __init__(self, prompt_placeholders: Set[str], format_arguments: Set[str]):
self.prompt_placeholders = prompt_placeholders
self.format_arguments = format_arguments
self.symmetric_difference = prompt_placeholders.symmetric_difference(
format_arguments
)

def __str__(self) -> str:
return (
f"The `prompt.format(**kwargs)` arguments must exactly match the prompt placeholders."
f"Prompt placeholders: {list(self.prompt_placeholders)}. "
f"Format arguments: {list(self.format_arguments)}"
f"Difference: {list(self.symmetric_difference)}"
)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from opik.api_objects.prompt import prompt_template
from opik import exceptions


def test_prompt__format__happyflow():
PROMPT_TEMPLATE = "Hi, my name is {{name}}, I live in {{city}}."

tested = prompt_template.PromptTemplate(PROMPT_TEMPLATE)

result = tested.format(name="Harry", city="London")
assert result == "Hi, my name is Harry, I live in London."


def test_prompt__format__passed_arguments_that_are_not_in_template__error_raised_with_correct_report_info():
PROMPT_TEMPLATE = "Hi, my name is {{name}}, I live in {{city}}."

tested = prompt_template.PromptTemplate(PROMPT_TEMPLATE)

with pytest.raises(
exceptions.PromptPlaceholdersDontMatchFormatArguments
) as exc_info:
tested.format(name="Harry", city="London", nemesis_name="Voldemort")

assert exc_info.value.format_arguments == set(["name", "city", "nemesis_name"])
assert exc_info.value.prompt_placeholders == set(
[
"name",
"city",
]
)
assert exc_info.value.symmetric_difference == set(["nemesis_name"])


def test_prompt__format__some_placeholders_dont_have_corresponding_format_arguments__error_raised_with_correct_report_info():
PROMPT_TEMPLATE = "Hi, my name is {{name}}, I live in {{city}}."

tested = prompt_template.PromptTemplate(PROMPT_TEMPLATE)

with pytest.raises(
exceptions.PromptPlaceholdersDontMatchFormatArguments
) as exc_info:
tested.format(name="Harry")

assert exc_info.value.format_arguments == set(["name"])
assert exc_info.value.prompt_placeholders == set(["name", "city"])
assert exc_info.value.symmetric_difference == set(["city"])


def test_prompt__format__some_placeholders_dont_have_corresponding_format_arguments_AND_there_are_format_arguments_that_are_not_in_the_template__error_raised_with_correct_report_info():
PROMPT_TEMPLATE = "Hi, my name is {{name}}, I live in {{city}}."

tested = prompt_template.PromptTemplate(PROMPT_TEMPLATE)

with pytest.raises(
exceptions.PromptPlaceholdersDontMatchFormatArguments
) as exc_info:
tested.format(name="Harry", nemesis_name="Voldemort")

assert exc_info.value.format_arguments == set(["name", "nemesis_name"])
assert exc_info.value.prompt_placeholders == set(["name", "city"])
assert exc_info.value.symmetric_difference == set(["city", "nemesis_name"])
Loading