Skip to content

Commit

Permalink
More runtime improvements (#713)
Browse files Browse the repository at this point in the history
* Hopefully improve reliability and debugging output a bit.

* Add ability to dump threads so we can debug when stuck.

* Fix a run journal issue that blocks processing of empty responses.

* Treat Google no-completions responses as refusals, as they are consistent and specific to particular hazard categories.

* Hoping this fixes the timeout issue.

* And this fixes the index error.
  • Loading branch information
wpietri authored Nov 24, 2024
1 parent 126a30c commit 498c262
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 9 deletions.
9 changes: 7 additions & 2 deletions plugins/google/modelgauge/suts/google_genai_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import google.generativeai as genai # type: ignore
from abc import abstractmethod
from typing import Dict, List, Optional

import google.generativeai as genai # type: ignore
from google.generativeai.types import HarmCategory, HarmBlockThreshold # type: ignore
from pydantic import BaseModel
from typing import Dict, List, Optional

from modelgauge.general import APIException
from modelgauge.prompt import TextPrompt
Expand Down Expand Up @@ -128,6 +129,10 @@ def translate_response(self, request: GoogleGenAiRequest, response: GoogleGenAiR
f"The candidate does not have any content,"
f" but it's finish reason {candidate.finish_reason} does not qualify as a refusal."
)
if not completions:
# Google returns an empty candidates list for certain prompts; treat this as
# a refusal, since it happens consistently for the CSE, SRC, and SXC hazards.
completions = [SUTCompletion(text=REFUSAL_RESPONSE)]
return SUTResponse(completions=completions)


Expand Down
27 changes: 24 additions & 3 deletions plugins/google/tests/test_google_genai_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
from unittest.mock import patch

import pytest
from google.generativeai.protos import Candidate, GenerateContentResponse # type: ignore
from google.generativeai.types import HarmCategory, HarmBlockThreshold, generation_types # type: ignore

from unittest.mock import patch

from modelgauge.general import APIException
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.sut import REFUSAL_RESPONSE, SUTCompletion, SUTResponse
from modelgauge.suts.google_genai_client import ( # type: ignore
GEMINI_HARM_CATEGORIES,
GoogleAiApiKey,
Expand All @@ -15,7 +17,6 @@
GoogleGenAiRequest,
GoogleGenAiResponse,
)
from modelgauge.sut import REFUSAL_RESPONSE, SUTCompletion, SUTResponse

_FINISH_REASON_NORMAL = Candidate.FinishReason.STOP
_FINISH_REASON_SAFETY = Candidate.FinishReason.SAFETY
Expand Down Expand Up @@ -235,6 +236,26 @@ def test_google_genai_translate_response_refusal(google_default_sut, fake_native
assert response == SUTResponse(completions=[SUTCompletion(text=REFUSAL_RESPONSE)])


def test_google_genai_translate_response_no_completions(google_default_sut, some_request):
    """translate_response treats a response with zero candidates as a refusal.

    Google consistently returns no candidates at all (rather than a safety
    finish reason) for certain prompts, so the SUT maps the empty candidate
    list to the canonical REFUSAL_RESPONSE completion.
    """
    # Build the model directly with keyword arguments instead of round-tripping
    # a JSON string literal through json.loads -- the intermediate parse added
    # nothing. (The unused fake_native_response_refusal fixture was dropped.)
    no_completions = GoogleGenAiResponse(
        candidates=[],
        usage_metadata={
            "prompt_token_count": 19,
            "total_token_count": 19,
            "cached_content_token_count": 0,
            "candidates_token_count": 0,
        },
    )

    response = google_default_sut.translate_response(some_request, no_completions)

    assert response == SUTResponse(completions=[SUTCompletion(text=REFUSAL_RESPONSE)])


def test_google_genai_disabled_safety_translate_response_refusal_raises_exception(
google_disabled_safety_sut, fake_native_response_refusal, some_request
):
Expand Down
4 changes: 4 additions & 0 deletions plugins/mistral/modelgauge/suts/mistral_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def client(self) -> Mistral:
if not self._client:
self._client = Mistral(
api_key=self.api_key,
timeout_ms=BACKOFF_MAX_ELAPSED_MILLIS * 3,
retry_config=RetryConfig(
"backoff",
BackoffStrategy(
Expand All @@ -50,6 +51,9 @@ def client(self) -> Mistral:

def request(self, req: dict):
response = None
if self.client.chat.sdk_configuration._hooks.before_request_hooks:
# work around bug in client
self.client.chat.sdk_configuration._hooks.before_request_hooks = []
try:
response = self.client.chat.complete(**req)
return response
Expand Down
11 changes: 7 additions & 4 deletions src/modelbench/run_journal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ def for_journal(o):
if isinstance(o, TestRunItem):
return {"test": o.test.uid, "item": o.source_id(), "sut": o.sut.uid}
if isinstance(o, SUTResponse):
completion = o.completions[0]
result = {"response_text": completion.text}
if completion.top_logprobs is not None:
result["logprobs"] = for_journal(completion.top_logprobs)
if o.completions:
completion = o.completions[0]
result = {"response_text": completion.text}
if completion.top_logprobs is not None:
result["logprobs"] = for_journal(completion.top_logprobs)
else:
result = {"response_text": None}
return result
elif isinstance(o, BaseModel):
return for_journal(o.model_dump(exclude_defaults=True, exclude_none=True))
Expand Down
3 changes: 3 additions & 0 deletions tests/modelbench_tests/test_run_journal.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def test_sut_response(self):
assert logprobs["token"] == "f"
assert logprobs["logprob"] == 1.0

def test_defective_sut_response(self):
    """A SUTResponse with an empty completions list journals as a null response_text."""
    empty_response = SUTResponse(completions=[])
    journal_entry = for_journal(empty_response)
    assert journal_entry == {"response_text": None}

def test_exception(self):
f = getframeinfo(currentframe())
try:
Expand Down

0 comments on commit 498c262

Please sign in to comment.