Use actual heldback prompts in official tests #696

Merged
merged 4 commits on Nov 19, 2024
Changes from all commits
22 changes: 14 additions & 8 deletions src/modelgauge/external_data.py
@@ -1,15 +1,14 @@
import requests # type: ignore
import shutil
import tempfile
import urllib.request
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from typing import Dict, Optional

import gdown # type: ignore
from tenacity import retry, stop_after_attempt, wait_exponential

from modelgauge.data_packing import DataDecompressor, DataUnpacker
from modelgauge.general import UrlRetrieveProgressBar


@dataclass(frozen=True, kw_only=True)
@@ -31,18 +30,25 @@ class WebData(ExternalData):
"""External data that can be trivially downloaded using wget."""

source_url: str
headers: Optional[Dict] = None

@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1),
reraise=True,
)
def download(self, location):
urllib.request.urlretrieve(
self.source_url,
location,
reporthook=UrlRetrieveProgressBar(self.source_url),
)
if self.headers:
response = requests.get(self.source_url, headers=self.headers)
else:
response = requests.get(self.source_url)
if response.ok:
with open(location, "wb") as f:
f.write(response.content)
else:
raise RuntimeError(
f"failed to fetch {self.source_url} with headers={self.headers}.\nResponse status: {response.status_code}: {response.text}"
)


@dataclass(frozen=True, kw_only=True)
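
For orientation, here is a minimal usage sketch of the headers support added above. It is not part of this PR; the URL, header value, and destination path are placeholders.

# Hypothetical usage of the new WebData headers field (placeholders throughout).
from modelgauge.external_data import WebData

data = WebData(
    source_url="https://example.org/files/download/some_prompt_set.csv",
    headers={"auth-token": "example-token"},  # whatever header the file server expects
)
data.download("/tmp/some_prompt_set.csv")  # writes the response body to this path; failed requests are retried up to 5 times
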
35 changes: 30 additions & 5 deletions src/modelgauge/tests/safe_v1.py
@@ -1,7 +1,7 @@
import csv
from abc import ABC
from enum import Enum
from typing import Dict, List, Mapping
from typing import Dict, List, Mapping, Optional

from modelgauge.aggregations import get_measurement_stats_by_key
from modelgauge.annotator_set import AnnotatorSet
@@ -10,6 +10,7 @@
from modelgauge.dependency_helper import DependencyHelper
from modelgauge.external_data import ExternalData, WebData
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.test_decorator import modelgauge_test
@@ -18,10 +19,20 @@

PROMPT_SETS = {
"practice": "airr_official_1.0_practice_prompt_set_release",
"official": "mock_airr_official_1.0_heldback_prompt_set_release", # TODO: SWAP OUT FOR REAL FILE after its uploaded
"official": "airr_official_1.0_heldback_prompt_set_release",
}


class ModellabFileDownloadToken(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="modellab_files",
key="token",
instructions="Please ask MLCommons admin for permission.",
)


class Locale(Enum):
EN_US = "en_US" # English, United States
FR_FR = "fr_FR" # French, France
@@ -92,7 +103,13 @@ def create_uid(hazard: str, locale: str | Locale, prompt_set: str, evaluator=Non
return f"safe-{hazard}-{locale}-{prompt_set}-1.0{postfix}".lower()

def __init__(
self, uid: str, hazard: str, locale: Locale, persona_types: List[SafePersonasVersion1], prompt_set: str
self,
uid: str,
hazard: str,
locale: Locale,
persona_types: List[SafePersonasVersion1],
prompt_set: str,
token: Optional[ModellabFileDownloadToken] = None,
):
self._check_annotators()
super().__init__(uid)
@@ -106,6 +123,7 @@ def __init__(
self.persona_types = persona_types
assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}."
self.prompt_set_file_name = PROMPT_SETS[prompt_set]
self.token = token

@classmethod
def _check_annotators(cls):
@@ -115,10 +133,14 @@ def _check_annotators(cls):

def get_dependencies(self) -> Mapping[str, ExternalData]:
modellab_base_download_url = "https://modellab.modelmodel.org/files/download"
headers = None
if self.token is not None:
headers = {"auth-token": self.token.value}
# Only one dependency.
return {
self.prompt_set_file_name: WebData(
source_url=f"{modellab_base_download_url}/{self.prompt_set_file_name}.csv"
source_url=f"{modellab_base_download_url}/{self.prompt_set_file_name}.csv",
headers=headers,
)
}

@@ -204,7 +226,10 @@ def register_tests(cls, evaluator=None):
test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator)
# TODO: Remove this 'if', duplicates are already caught during registration and should raise errors.
if not test_uid in TESTS.keys():
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set)
token = None
if prompt_set == "official":
token = InjectSecret(ModellabFileDownloadToken)
TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token)


# default llama guard annotator, always
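
For orientation (not part of the diff): with the registration change above, an "official" test can only be instantiated when a secret with scope "modellab_files" and key "token" is available. A rough caller-side sketch, assuming modelgauge's usual secret-loading and registry entry points; the hazard and locale in the UID are chosen for illustration.

# Hypothetical caller-side sketch (not in this PR); assumes the standard modelgauge entry points.
from modelgauge.config import load_secrets_from_config
from modelgauge.test_registry import TESTS
import modelgauge.tests.safe_v1  # noqa: F401  # importing the module registers the SAFE v1 tests

secrets = load_secrets_from_config()  # must now include scope "modellab_files" with key "token"
test = TESTS.make_instance("safe-dfm-en_us-official-1.0", secrets=secrets)  # UID shape per create_uid above
deps = test.get_dependencies()  # the WebData dependency now carries the auth-token header
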
4 changes: 2 additions & 2 deletions tests/modelbench_tests/test_benchmark.py
@@ -53,7 +53,7 @@ def test_benchmark_definition_basics():


@pytest.mark.parametrize("prompt_set", PROMPT_SETS.keys())
def test_benchmark_v1_definition_basics(prompt_set):
def test_benchmark_v1_definition_basics(prompt_set, fake_secrets):
mbb = GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, prompt_set)
assert mbb.uid == f"general_purpose_ai_chat_benchmark-1.0-en_us-{prompt_set}-default"
assert mbb.name() == "General Purpose Ai Chat Benchmark V 1"
@@ -66,7 +66,7 @@ def test_benchmark_v1_definition_basics(prompt_set):
assert hazard.hazard_key == hazard_key
assert hazard.locale == Locale.EN_US
assert hazard.prompt_set == prompt_set
assert prompt_set in hazard.tests(secrets={})[0].prompt_set_file_name
assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_name


@pytest.mark.parametrize(
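
These benchmark tests previously resolved hazards with secrets={}; because official tests now require the modellab token, they take a fake_secrets fixture instead. A guess at the minimal shape such a fixture needs; the real fixture lives in the suite's conftest and may be broader.

# Hypothetical minimal fixture; the project's actual fake_secrets fixture may differ.
import pytest

@pytest.fixture
def fake_secrets():
    # RawSecrets-style nested mapping: scope -> key -> value
    return {"modellab_files": {"token": "fake-token"}}
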
Binary file modified tests/modelgauge_tests/data/sample_cache.sqlite
Binary file not shown.
16 changes: 12 additions & 4 deletions tests/modelgauge_tests/test_external_data.py
@@ -9,14 +9,22 @@
from tenacity import wait_none


WebDataMockResponse = namedtuple("WebDataMockResponse", ("ok", "content"))
GDriveFileToDownload = namedtuple("GDriveFileToDownload", ("id", "path"))


def test_web_data_download(mocker):
mock_download = mocker.patch("urllib.request.urlretrieve")
def test_web_data_download(mocker, tmpdir):
mock_download = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test"))
web_data = WebData(source_url="http://example.com")
web_data.download("test.tgz")
mock_download.assert_called_once_with("http://example.com", "test.tgz", reporthook=ANY)
web_data.download(tmpdir / "file.txt")
mock_download.assert_called_once_with("http://example.com")


def test_web_data_download_with_headers(mocker, tmpdir):
mock_download = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test"))
web_data = WebData(source_url="http://example.com", headers={"token": "secret"})
web_data.download(tmpdir / "file.txt")
mock_download.assert_called_once_with("http://example.com", headers={"token": "secret"})


def test_gdrive_data_download(mocker):
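
A follow-up test this PR does not include would cover the new failure branch in WebData.download. A sketch using the same mocking pattern as the tests above; the extra status_code and text fields are the ones the error message reads.

# Sketch of an error-path test (not part of this PR).
from collections import namedtuple

import pytest
from tenacity import wait_none

from modelgauge.external_data import WebData

FailedMockResponse = namedtuple("FailedMockResponse", ("ok", "content", "status_code", "text"))


def test_web_data_download_failure(mocker, tmpdir):
    mocker.patch(
        "requests.get",
        return_value=FailedMockResponse(ok=False, content=b"", status_code=403, text="forbidden"),
    )
    web_data = WebData(source_url="http://example.com")
    web_data.download.retry.wait = wait_none()  # skip the exponential backoff between retry attempts
    with pytest.raises(RuntimeError, match="failed to fetch"):
        web_data.download(tmpdir / "file.txt")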