From 794e9df323b8a482f5e142a11705d2dbacb5c82e Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 18 Nov 2024 12:53:36 -0800 Subject: [PATCH 1/4] inject token for v1 tests that use official prompts --- src/modelgauge/tests/safe_v1.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/src/modelgauge/tests/safe_v1.py b/src/modelgauge/tests/safe_v1.py index 0f6b63c7..fd2e2ed5 100644 --- a/src/modelgauge/tests/safe_v1.py +++ b/src/modelgauge/tests/safe_v1.py @@ -1,7 +1,7 @@ import csv from abc import ABC from enum import Enum -from typing import Dict, List, Mapping +from typing import Dict, List, Mapping, Optional from modelgauge.aggregations import get_measurement_stats_by_key from modelgauge.annotator_set import AnnotatorSet @@ -10,6 +10,7 @@ from modelgauge.dependency_helper import DependencyHelper from modelgauge.external_data import ExternalData, WebData from modelgauge.prompt import SUTOptions, TextPrompt +from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.test_decorator import modelgauge_test @@ -22,6 +23,16 @@ } +class PrivatePromptsToken(RequiredSecret): + @classmethod + def description(cls) -> SecretDescription: + return SecretDescription( + scope="airr_prompts", + key="token", + instructions="Please ask MLCommons admin for permission.", + ) + + class Locale(Enum): EN_US = "en_US" # English, United States FR_FR = "fr_FR" # French, France @@ -92,7 +103,13 @@ def create_uid(hazard: str, locale: str | Locale, prompt_set: str, evaluator=Non return f"safe-{hazard}-{locale}-{prompt_set}-1.0{postfix}".lower() def __init__( - self, uid: str, hazard: str, locale: Locale, persona_types: List[SafePersonasVersion1], prompt_set: str + self, + uid: str, + hazard: str, + locale: Locale, + persona_types: List[SafePersonasVersion1], + prompt_set: str, + token: Optional[PrivatePromptsToken] = None, ): self._check_annotators() super().__init__(uid) @@ -106,6 +123,7 @@ def __init__( self.persona_types = persona_types assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}." self.prompt_set_file_name = PROMPT_SETS[prompt_set] + self.token = token @classmethod def _check_annotators(cls): @@ -114,6 +132,7 @@ def _check_annotators(cls): raise NotImplementedError("Concrete SafeTestVersion1 classes must set class-attribute `annotators`.") def get_dependencies(self) -> Mapping[str, ExternalData]: + # TODO: Pass token in header. modellab_base_download_url = "https://modellab.modelmodel.org/files/download" # Only one dependency. return { @@ -204,7 +223,10 @@ def register_tests(cls, evaluator=None): test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator) # TODO: Remove this 'if', duplicates are already caught during registration and should raise errors. if not test_uid in TESTS.keys(): - TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set) + token = None + if prompt_set == "official": + token = InjectSecret(PrivatePromptsToken) + TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token) # default llama guard annotator, always From 7ea6879345cc1a3b14f6c388f2417f44cb620592 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 18 Nov 2024 13:20:58 -0800 Subject: [PATCH 2/4] WebData can optionally pass headers --- src/modelgauge/external_data.py | 22 ++++++++++++++-------- src/modelgauge/tests/safe_v1.py | 17 ++++++++++------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/modelgauge/external_data.py b/src/modelgauge/external_data.py index 3aef331b..52cf1d27 100644 --- a/src/modelgauge/external_data.py +++ b/src/modelgauge/external_data.py @@ -1,15 +1,14 @@ +import requests import shutil import tempfile -import urllib.request from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional import gdown # type: ignore from tenacity import retry, stop_after_attempt, wait_exponential from modelgauge.data_packing import DataDecompressor, DataUnpacker -from modelgauge.general import UrlRetrieveProgressBar @dataclass(frozen=True, kw_only=True) @@ -31,6 +30,7 @@ class WebData(ExternalData): """External data that can be trivially downloaded using wget.""" source_url: str + headers: Optional[Dict] = None @retry( stop=stop_after_attempt(5), @@ -38,11 +38,17 @@ class WebData(ExternalData): reraise=True, ) def download(self, location): - urllib.request.urlretrieve( - self.source_url, - location, - reporthook=UrlRetrieveProgressBar(self.source_url), - ) + if self.headers: + response = requests.get(self.source_url, headers=self.headers) + else: + response = requests.get(self.source_url) + if response.ok: + with open(location, "wb") as f: + f.write(response.content) + else: + raise RuntimeError( + f"failed to fetch {self.source_url} with headers={self.headers}.\nResponse status: {response.status_code}: {response.text}" + ) @dataclass(frozen=True, kw_only=True) diff --git a/src/modelgauge/tests/safe_v1.py b/src/modelgauge/tests/safe_v1.py index fd2e2ed5..492306b9 100644 --- a/src/modelgauge/tests/safe_v1.py +++ b/src/modelgauge/tests/safe_v1.py @@ -19,15 +19,15 @@ PROMPT_SETS = { "practice": "airr_official_1.0_practice_prompt_set_release", - "official": "mock_airr_official_1.0_heldback_prompt_set_release", # TODO: SWAP OUT FOR REAL FILE after its uploaded + "official": "airr_official_1.0_heldback_prompt_set_release", } -class PrivatePromptsToken(RequiredSecret): +class ModellabFileDownloadToken(RequiredSecret): @classmethod def description(cls) -> SecretDescription: return SecretDescription( - scope="airr_prompts", + scope="modellab_files", key="token", instructions="Please ask MLCommons admin for permission.", ) @@ -109,7 +109,7 @@ def __init__( locale: Locale, persona_types: List[SafePersonasVersion1], prompt_set: str, - token: Optional[PrivatePromptsToken] = None, + token: Optional[ModellabFileDownloadToken] = None, ): self._check_annotators() super().__init__(uid) @@ -132,12 +132,15 @@ def _check_annotators(cls): raise NotImplementedError("Concrete SafeTestVersion1 classes must set class-attribute `annotators`.") def get_dependencies(self) -> Mapping[str, ExternalData]: - # TODO: Pass token in header. modellab_base_download_url = "https://modellab.modelmodel.org/files/download" + headers = None + if self.token is not None: + headers = {"auth-token": self.token.value} # Only one dependency. return { self.prompt_set_file_name: WebData( - source_url=f"{modellab_base_download_url}/{self.prompt_set_file_name}.csv" + source_url=f"{modellab_base_download_url}/{self.prompt_set_file_name}.csv", + headers=headers, ) } @@ -225,7 +228,7 @@ def register_tests(cls, evaluator=None): if not test_uid in TESTS.keys(): token = None if prompt_set == "official": - token = InjectSecret(PrivatePromptsToken) + token = InjectSecret(ModellabFileDownloadToken) TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token) From 62142bb14f260395ca12c3809abc907cea255d7a Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 18 Nov 2024 13:45:10 -0800 Subject: [PATCH 3/4] fix tests --- tests/modelbench_tests/test_benchmark.py | 4 ++-- .../modelgauge_tests/data/sample_cache.sqlite | Bin 12288 -> 12288 bytes tests/modelgauge_tests/test_external_data.py | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/modelbench_tests/test_benchmark.py b/tests/modelbench_tests/test_benchmark.py index 9b683744..9c028412 100644 --- a/tests/modelbench_tests/test_benchmark.py +++ b/tests/modelbench_tests/test_benchmark.py @@ -53,7 +53,7 @@ def test_benchmark_definition_basics(): @pytest.mark.parametrize("prompt_set", PROMPT_SETS.keys()) -def test_benchmark_v1_definition_basics(prompt_set): +def test_benchmark_v1_definition_basics(prompt_set, fake_secrets): mbb = GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, prompt_set) assert mbb.uid == f"general_purpose_ai_chat_benchmark-1.0-en_us-{prompt_set}-default" assert mbb.name() == "General Purpose Ai Chat Benchmark V 1" @@ -66,7 +66,7 @@ def test_benchmark_v1_definition_basics(prompt_set): assert hazard.hazard_key == hazard_key assert hazard.locale == Locale.EN_US assert hazard.prompt_set == prompt_set - assert prompt_set in hazard.tests(secrets={})[0].prompt_set_file_name + assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_name @pytest.mark.parametrize( diff --git a/tests/modelgauge_tests/data/sample_cache.sqlite b/tests/modelgauge_tests/data/sample_cache.sqlite index 248d66ba5649d80739edce32b5469e5e2586bc0b..15025f27bd5658a9f835ab49b9b694031b66e8a3 100644 GIT binary patch delta 31 mcmZojXh@il#1}8i%m0IciNBYDznA~=W{Do#R>qM@(Eu6 delta 31 mcmZojXh@il#Akn#m;VO?6MruQe=q;#&4LQD{F{6EixmK=)(QIn diff --git a/tests/modelgauge_tests/test_external_data.py b/tests/modelgauge_tests/test_external_data.py index 38297b72..8a9b62dc 100644 --- a/tests/modelgauge_tests/test_external_data.py +++ b/tests/modelgauge_tests/test_external_data.py @@ -9,14 +9,22 @@ from tenacity import wait_none +WebDataMockResponse = namedtuple("WebDataMockResponse", ("ok", "content")) GDriveFileToDownload = namedtuple("GDriveFileToDownload", ("id", "path")) -def test_web_data_download(mocker): - mock_download = mocker.patch("urllib.request.urlretrieve") +def test_web_data_download(mocker, tmpdir): + mock_download = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test")) web_data = WebData(source_url="http://example.com") - web_data.download("test.tgz") - mock_download.assert_called_once_with("http://example.com", "test.tgz", reporthook=ANY) + web_data.download(tmpdir / "file.txt") + mock_download.assert_called_once_with("http://example.com") + + +def test_web_data_download_with_headers(mocker, tmpdir): + mock_download = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test")) + web_data = WebData(source_url="http://example.com", headers={"token": "secret"}) + web_data.download(tmpdir / "file.txt") + mock_download.assert_called_once_with("http://example.com", headers={"token": "secret"}) def test_gdrive_data_download(mocker): From b30b03064faa4823c921e30f3d468b2e06401a48 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 18 Nov 2024 13:48:53 -0800 Subject: [PATCH 4/4] mypy --- src/modelgauge/external_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modelgauge/external_data.py b/src/modelgauge/external_data.py index 52cf1d27..72e57bad 100644 --- a/src/modelgauge/external_data.py +++ b/src/modelgauge/external_data.py @@ -1,4 +1,4 @@ -import requests +import requests # type: ignore import shutil import tempfile from abc import ABC, abstractmethod