Skip to content

Commit

Permalink
Use actual heldback prompts in official tests (#696)
Browse files Browse the repository at this point in the history
* inject token for v1 tests that use official prompts

* WebData can optionally pass headers

* fix tests

* mypy
  • Loading branch information
bkorycki authored Nov 19, 2024
1 parent e469bc8 commit a4e1eca
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 19 deletions.
22 changes: 14 additions & 8 deletions src/modelgauge/external_data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import requests # type: ignore
import shutil
import tempfile
import urllib.request
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from typing import Dict, Optional

import gdown # type: ignore
from tenacity import retry, stop_after_attempt, wait_exponential

from modelgauge.data_packing import DataDecompressor, DataUnpacker
from modelgauge.general import UrlRetrieveProgressBar


@dataclass(frozen=True, kw_only=True)
Expand All @@ -31,18 +30,25 @@ class WebData(ExternalData):
"""External data that can be trivially downloaded using wget."""

source_url: str
headers: Optional[Dict] = None

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=1),
    reraise=True,
)
def download(self, location):
    """Fetch ``self.source_url`` and write the response body to ``location``.

    If ``self.headers`` is set (e.g. an auth token for restricted files),
    they are sent with the request. The whole attempt is retried up to 5
    times with exponential backoff via tenacity, re-raising the last error.

    Args:
        location: Filesystem path to write the downloaded bytes to.

    Raises:
        RuntimeError: If the server responds with a non-success status.
    """
    # NOTE: the scraped diff also contained the removed urllib.request.urlretrieve
    # implementation; only the requests-based version below is live code.
    if self.headers:
        response = requests.get(self.source_url, headers=self.headers)
    else:
        response = requests.get(self.source_url)
    if response.ok:
        with open(location, "wb") as f:
            f.write(response.content)
    else:
        raise RuntimeError(
            f"failed to fetch {self.source_url} with headers={self.headers}.\nResponse status: {response.status_code}: {response.text}"
        )


@dataclass(frozen=True, kw_only=True)
Expand Down
35 changes: 30 additions & 5 deletions src/modelgauge/tests/safe_v1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import csv
from abc import ABC
from enum import Enum
from typing import Dict, List, Mapping
from typing import Dict, List, Mapping, Optional

from modelgauge.aggregations import get_measurement_stats_by_key
from modelgauge.annotator_set import AnnotatorSet
Expand All @@ -10,6 +10,7 @@
from modelgauge.dependency_helper import DependencyHelper
from modelgauge.external_data import ExternalData, WebData
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.test_decorator import modelgauge_test
Expand All @@ -18,10 +19,20 @@

# Maps a prompt-set short name to the modellab file name (without the .csv
# extension) that the safe tests download as their prompt dependency.
# The scraped diff carried both the old mock "official" entry and the real
# one; only the real heldback file remains.
PROMPT_SETS = {
    "practice": "airr_official_1.0_practice_prompt_set_release",
    "official": "airr_official_1.0_heldback_prompt_set_release",
}


class ModellabFileDownloadToken(RequiredSecret):
    """Secret auth token required to download restricted modellab files."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe where this secret is stored and how to obtain it."""
        secret_location = SecretDescription(
            scope="modellab_files",
            key="token",
            instructions="Please ask MLCommons admin for permission.",
        )
        return secret_location


class Locale(Enum):
EN_US = "en_US" # English, United States
FR_FR = "fr_FR" # French, France
Expand Down Expand Up @@ -92,7 +103,13 @@ def create_uid(hazard: str, locale: str | Locale, prompt_set: str, evaluator=Non
return f"safe-{hazard}-{locale}-{prompt_set}-1.0{postfix}".lower()

def __init__(
self, uid: str, hazard: str, locale: Locale, persona_types: List[SafePersonasVersion1], prompt_set: str
self,
uid: str,
hazard: str,
locale: Locale,
persona_types: List[SafePersonasVersion1],
prompt_set: str,
token: Optional[ModellabFileDownloadToken] = None,
):
self._check_annotators()
super().__init__(uid)
Expand All @@ -106,6 +123,7 @@ def __init__(
self.persona_types = persona_types
assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}."
self.prompt_set_file_name = PROMPT_SETS[prompt_set]
self.token = token

@classmethod
def _check_annotators(cls):
Expand All @@ -115,10 +133,14 @@ def _check_annotators(cls):

def get_dependencies(self) -> Mapping[str, ExternalData]:
    """Return this test's single external dependency: its prompt-set CSV.

    When a download token was injected (the "official" heldback prompt set),
    it is forwarded as an "auth-token" header so modellab will serve the
    restricted file; otherwise no headers are sent.
    """
    modellab_base_download_url = "https://modellab.modelmodel.org/files/download"
    headers = None
    if self.token is not None:
        headers = {"auth-token": self.token.value}
    # Only one dependency.
    # (The scraped diff also showed the removed pre-headers source_url line;
    # this is the post-commit call with the headers argument.)
    return {
        self.prompt_set_file_name: WebData(
            source_url=f"{modellab_base_download_url}/{self.prompt_set_file_name}.csv",
            headers=headers,
        )
    }

Expand Down Expand Up @@ -204,7 +226,10 @@ def register_tests(cls, evaluator=None):
test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator)
# TODO: Remove this 'if', duplicates are already caught during registration and should raise errors.
if not test_uid in TESTS.keys():
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set)
token = None
if prompt_set == "official":
token = InjectSecret(ModellabFileDownloadToken)
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token)


# default llama guard annotator, always
Expand Down
4 changes: 2 additions & 2 deletions tests/modelbench_tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_benchmark_definition_basics():


@pytest.mark.parametrize("prompt_set", PROMPT_SETS.keys())
def test_benchmark_v1_definition_basics(prompt_set):
def test_benchmark_v1_definition_basics(prompt_set, fake_secrets):
mbb = GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, prompt_set)
assert mbb.uid == f"general_purpose_ai_chat_benchmark-1.0-en_us-{prompt_set}-default"
assert mbb.name() == "General Purpose Ai Chat Benchmark V 1"
Expand All @@ -66,7 +66,7 @@ def test_benchmark_v1_definition_basics(prompt_set):
assert hazard.hazard_key == hazard_key
assert hazard.locale == Locale.EN_US
assert hazard.prompt_set == prompt_set
assert prompt_set in hazard.tests(secrets={})[0].prompt_set_file_name
assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_name


@pytest.mark.parametrize(
Expand Down
Binary file modified tests/modelgauge_tests/data/sample_cache.sqlite
Binary file not shown.
16 changes: 12 additions & 4 deletions tests/modelgauge_tests/test_external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,22 @@
from tenacity import wait_none


# Lightweight stand-ins used by the download tests below:
# a fake requests response (just the attributes the code reads) and
# a record of what gdown was asked to fetch.
WebDataMockResponse = namedtuple("WebDataMockResponse", "ok content")
GDriveFileToDownload = namedtuple("GDriveFileToDownload", "id path")


def test_web_data_download(mocker, tmpdir):
    """Header-less WebData.download GETs the bare URL and writes the body.

    (The scraped diff interleaved the removed urllib-mocking version of this
    test with the requests-mocking one; this is the post-commit test, plus a
    check that the response bytes actually land in the target file.)
    """
    mock_download = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test"))
    web_data = WebData(source_url="http://example.com")
    web_data.download(tmpdir / "file.txt")
    mock_download.assert_called_once_with("http://example.com")
    # The mocked response content should have been written to disk.
    assert (tmpdir / "file.txt").read_binary() == b"test"


def test_web_data_download_with_headers(mocker, tmpdir):
    """WebData.download passes its configured headers through to requests.get."""
    patched_get = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test"))
    data_source = WebData(source_url="http://example.com", headers={"token": "secret"})
    data_source.download(tmpdir / "file.txt")
    patched_get.assert_called_once_with("http://example.com", headers={"token": "secret"})


def test_gdrive_data_download(mocker):
Expand Down

0 comments on commit a4e1eca

Please sign in to comment.