Fixes Together AI API wrapper for multiple choice questions
- Put the options into the prompt.
- Evaluate the log probs of the modified prompt.
- Fix the prompt length adjustment bug.
- Fixes #95

PiperOrigin-RevId: 688130886
Change-Id: Iaf91c3e6d6c9137981b82fe9dc23b7de882cb451
vezhnick authored and copybara-github committed Oct 21, 2024
1 parent bad93f9 commit 281847f
Showing 1 changed file with 77 additions and 74 deletions.
151 changes: 77 additions & 74 deletions concordia/language_model/together_ai.py
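
For context on the fix: the wrapper now answers a multiple-choice question by appending each candidate option to the prompt and scoring it by the log probability the model assigns to that continuation, then taking the argmax. The sketch below illustrates the idea in isolation; score_option is a hypothetical stand-in for the wrapper's _sample_choice and does not call the Together AI API.

import numpy as np

def score_option(prompt: str, option: str) -> float:
  # Hypothetical stand-in for _sample_choice: the real method asks the API
  # to echo prompt + option and sums the option's token log probs. A fake
  # score is returned here so the sketch runs without network access.
  return -0.5 * len(option)

def pick_choice(prompt: str, options: list[str]) -> tuple[int, str]:
  # Score each option independently and take the most likely continuation,
  # mirroring the argmax over log probs in sample_choice below.
  scores = np.array([score_option(prompt, o) for o in options])
  idx = int(np.argmax(scores))
  return idx, options[idx]

idx, choice = pick_choice('Question: Is Jake a turtle?\nAnswer: Jake is ',
                          ['not a turtle.', 'a turtle.'])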
@@ -207,6 +207,81 @@ def sample_text(

    return result

  def _sample_choice(
      self, prompt: str, response: str) -> float:
    """Returns the log probability of the prompt and response."""
    original_prompt = prompt
    augmented_prompt = _ensure_prompt_not_too_long(prompt, len(response))
    attempts = 0
    for attempts in range(_MAX_ATTEMPTS):
      if attempts > 0:
        seconds_to_sleep = (_SECONDS_TO_SLEEP_WHEN_RATE_LIMITED +
                            random.uniform(-_JITTER_SECONDS, _JITTER_SECONDS))
        if attempts >= _NUM_SILENT_ATTEMPTS:
          print(
              f'Sleeping for {seconds_to_sleep} seconds... '
              + f'attempt: {attempts} / {_MAX_ATTEMPTS}'
          )
        time.sleep(seconds_to_sleep)
      try:
        messages = [
            {
                'role': 'system',
                'content': (
                    'You always continue sentences provided '
                    + 'by the user and you never repeat what '
                    + 'the user already said.'
                ),
            },
            {
                'role': 'user',
                'content': 'Question: Is Jake a turtle?\nAnswer: Jake is ',
            },
            {'role': 'assistant', 'content': 'not a turtle.'},
            {
                'role': 'user',
                'content': (
                    'Question: What is Priya doing right now?\nAnswer: '
                    + 'Priya is currently '
                ),
            },
            {'role': 'assistant', 'content': 'sleeping.'},
            {'role': 'user', 'content': augmented_prompt},
            {'role': 'assistant', 'content': response},
        ]
        result = self._client.chat.completions.create(
            model=self._model_name,
            messages=messages,
            max_tokens=1,
            seed=None,
            logprobs=1,
            stream=False,
            echo=True,
        )
      except (together.error.RateLimitError,
              together.error.APIError,
              together.error.ServiceUnavailableError) as err:
        if attempts >= _NUM_SILENT_ATTEMPTS:
          print(f'  Exception: {err}')
          print(f'  Choice exception prompt: {augmented_prompt}')
        if isinstance(err, together.error.APIError):
          # If we hit the error that arises from a prompt that is too long,
          # re-run the trimming function with a more pessimistic guess of
          # the number of characters per token.
          augmented_prompt = _ensure_prompt_not_too_long(
              original_prompt, 1, guess_chars_per_token=1
          )
        continue
      else:
        # Drop the first token's score since it is always None.
        score = sum(result.prompt[0].logprobs.token_logprobs[1:])
        return score

    raise language_model.InvalidResponseError(
        f'Failed to get logprobs after {attempts + 1} attempts.\n Exception'
        f' prompt: {augmented_prompt}'
    )

  @override
  def sample_choice(
      self,
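
A note on the scoring trick in the new method above: the request sets max_tokens=1 but passes echo=True with logprobs=1, so the response carries log probabilities for the echoed input tokens, including the assistant turn holding the candidate response; summing them, minus the first token, yields the score used to rank choices. Below is a minimal sketch of that extraction, using a mock object assumed to match the shape the code reads (result.prompt[0].logprobs.token_logprobs):

from types import SimpleNamespace

# Mock response shaped like the object the code above reads; real values
# would come from the Together API when echo=True and logprobs=1 are set.
result = SimpleNamespace(prompt=[SimpleNamespace(logprobs=SimpleNamespace(
    token_logprobs=[None, -0.3, -1.2, -0.05]))])

# Drop the leading None (the first token has nothing to condition on), then
# sum the rest to get the sequence log probability.
score = sum(result.prompt[0].logprobs.token_logprobs[1:])
print(score)  # -1.55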
@@ -216,82 +291,10 @@ def sample_choice
      seed: int | None = None,
  ) -> tuple[int, str, dict[str, float]]:

    def _sample_choice(response: str) -> float:
      augmented_prompt = prompt + response
      original_augmented_prompt = augmented_prompt
      augmented_prompt = _ensure_prompt_not_too_long(augmented_prompt, 1)
      messages = [
          {
              'role': 'system',
              'content': (
                  'You always continue sentences provided '
                  + 'by the user and you never repeat what '
                  + 'the user already said.'
              ),
          },
          {
              'role': 'user',
              'content': 'Question: Is Jake a turtle?\nAnswer: Jake is ',
          },
          {'role': 'assistant', 'content': 'not a turtle.'},
          {
              'role': 'user',
              'content': (
                  'Question: What is Priya doing right now?\nAnswer: '
                  + 'Priya is currently '
              ),
          },
          {'role': 'assistant', 'content': 'sleeping.'},
          {'role': 'user', 'content': augmented_prompt},
      ]

      result = None
      for attempts in range(_MAX_ATTEMPTS):
        if attempts > 0:
          seconds_to_sleep = (_SECONDS_TO_SLEEP_WHEN_RATE_LIMITED +
                              random.uniform(-_JITTER_SECONDS, _JITTER_SECONDS))
          if attempts >= _NUM_SILENT_ATTEMPTS:
            print(
                f'Sleeping for {seconds_to_sleep} seconds... '
                + f'attempt: {attempts} / {_MAX_ATTEMPTS}'
            )
          time.sleep(seconds_to_sleep)
        try:
          result = self._client.chat.completions.create(
              model=self._model_name,
              messages=messages,
              max_tokens=1,
              seed=seed,
              logprobs=1,
              stream=False,
          )
        except (together.error.RateLimitError,
                together.error.APIError,
                together.error.ServiceUnavailableError) as err:
          if attempts >= _NUM_SILENT_ATTEMPTS:
            print(f'  Exception: {err}')
            print(f'  Choice exception prompt: {augmented_prompt}')
          if isinstance(err, together.error.APIError):
            # If we hit the error that arises from a prompt that is too long,
            # re-run the trimming function with a more pessimistic guess of
            # the number of characters per token.
            augmented_prompt = _ensure_prompt_not_too_long(
                original_augmented_prompt, 1, guess_chars_per_token=1)
          continue
        else:
          break

      if result:
        lp = sum(result.choices[0].logprobs.token_logprobs)
      else:
        raise ValueError(
            f'Failed to get logprobs.\nException prompt: {augmented_prompt}')

      return lp

    sample_choice_for_prompt = lambda x: self._sample_choice(prompt, x)
    with concurrent.futures.ThreadPoolExecutor() as executor:
      logprobs_np = np.array(
          list(executor.map(_sample_choice, responses))
          list(executor.map(sample_choice_for_prompt, responses))
      ).reshape(-1)

    idx = np.argmax(logprobs_np)
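The refactor above replaces the deleted inner closure with the shared _sample_choice method, adapted via a lambda that fixes the prompt, and fans per-response scoring out over a thread pool. A self-contained sketch of that pattern, with a dummy scorer standing in for the network call:

import concurrent.futures

import numpy as np

def score(prompt: str, response: str) -> float:
  # Dummy stand-in for self._sample_choice; the real call is I/O-bound,
  # which is why a thread pool is a good fit here.
  return -0.5 * len(response)

prompt = 'Question: Is Jake a turtle?\nAnswer: Jake is '
responses = ['not a turtle.', 'a turtle.', 'a dog.']

score_for_prompt = lambda r: score(prompt, r)
with concurrent.futures.ThreadPoolExecutor() as executor:
  logprobs_np = np.array(
      list(executor.map(score_for_prompt, responses))
  ).reshape(-1)

idx = int(np.argmax(logprobs_np))  # Index of the most likely response.
print(responses[idx])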
