diff --git a/src/modelgauge/external_data.py b/src/modelgauge/external_data.py index 3aef331b..72e57bad 100644 --- a/src/modelgauge/external_data.py +++ b/src/modelgauge/external_data.py @@ -1,15 +1,14 @@ +import requests # type: ignore import shutil import tempfile -import urllib.request from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional import gdown # type: ignore from tenacity import retry, stop_after_attempt, wait_exponential from modelgauge.data_packing import DataDecompressor, DataUnpacker -from modelgauge.general import UrlRetrieveProgressBar @dataclass(frozen=True, kw_only=True) @@ -31,6 +30,7 @@ class WebData(ExternalData): """External data that can be trivially downloaded using wget.""" source_url: str + headers: Optional[Dict] = None @retry( stop=stop_after_attempt(5), @@ -38,11 +38,17 @@ class WebData(ExternalData): reraise=True, ) def download(self, location): - urllib.request.urlretrieve( - self.source_url, - location, - reporthook=UrlRetrieveProgressBar(self.source_url), - ) + if self.headers: + response = requests.get(self.source_url, headers=self.headers) + else: + response = requests.get(self.source_url) + if response.ok: + with open(location, "wb") as f: + f.write(response.content) + else: + raise RuntimeError( + f"failed to fetch {self.source_url} with headers={self.headers}.\nResponse status: {response.status_code}: {response.text}" + ) @dataclass(frozen=True, kw_only=True) diff --git a/src/modelgauge/tests/safe_v1.py b/src/modelgauge/tests/safe_v1.py index 0f6b63c7..492306b9 100644 --- a/src/modelgauge/tests/safe_v1.py +++ b/src/modelgauge/tests/safe_v1.py @@ -1,7 +1,7 @@ import csv from abc import ABC from enum import Enum -from typing import Dict, List, Mapping +from typing import Dict, List, Mapping, Optional from modelgauge.aggregations import get_measurement_stats_by_key from modelgauge.annotator_set import AnnotatorSet @@ -10,6 +10,7 @@ from modelgauge.dependency_helper import DependencyHelper from modelgauge.external_data import ExternalData, WebData from modelgauge.prompt import SUTOptions, TextPrompt +from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem, TestItemAnnotations from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.test_decorator import modelgauge_test @@ -18,10 +19,20 @@ PROMPT_SETS = { "practice": "airr_official_1.0_practice_prompt_set_release", - "official": "mock_airr_official_1.0_heldback_prompt_set_release", # TODO: SWAP OUT FOR REAL FILE after its uploaded + "official": "airr_official_1.0_heldback_prompt_set_release", } +class ModellabFileDownloadToken(RequiredSecret): + @classmethod + def description(cls) -> SecretDescription: + return SecretDescription( + scope="modellab_files", + key="token", + instructions="Please ask MLCommons admin for permission.", + ) + + class Locale(Enum): EN_US = "en_US" # English, United States FR_FR = "fr_FR" # French, France @@ -92,7 +103,13 @@ def create_uid(hazard: str, locale: str | Locale, prompt_set: str, evaluator=Non return f"safe-{hazard}-{locale}-{prompt_set}-1.0{postfix}".lower() def __init__( - self, uid: str, hazard: str, locale: Locale, persona_types: List[SafePersonasVersion1], prompt_set: str + self, + uid: str, + hazard: str, + locale: Locale, + persona_types: List[SafePersonasVersion1], + prompt_set: str, + token: Optional[ModellabFileDownloadToken] = None, ): self._check_annotators() super().__init__(uid) @@ -106,6 +123,7 @@ def __init__( self.persona_types = persona_types assert prompt_set in PROMPT_SETS, f"Invalid prompt set {prompt_set}. Must be one of {PROMPT_SETS.keys()}." self.prompt_set_file_name = PROMPT_SETS[prompt_set] + self.token = token @classmethod def _check_annotators(cls): @@ -115,10 +133,14 @@ def _check_annotators(cls): def get_dependencies(self) -> Mapping[str, ExternalData]: modellab_base_download_url = "https://modellab.modelmodel.org/files/download" + headers = None + if self.token is not None: + headers = {"auth-token": self.token.value} # Only one dependency. return { self.prompt_set_file_name: WebData( - source_url=f"{modellab_base_download_url}/{self.prompt_set_file_name}.csv" + source_url=f"{modellab_base_download_url}/{self.prompt_set_file_name}.csv", + headers=headers, ) } @@ -204,7 +226,10 @@ def register_tests(cls, evaluator=None): test_uid = BaseSafeTestVersion1.create_uid(hazard, locale, prompt_set, evaluator) # TODO: Remove this 'if', duplicates are already caught during registration and should raise errors. if not test_uid in TESTS.keys(): - TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set) + token = None + if prompt_set == "official": + token = InjectSecret(ModellabFileDownloadToken) + TESTS.register(cls, test_uid, hazard, locale, ALL_PERSONAS, prompt_set, token) # default llama guard annotator, always diff --git a/tests/modelbench_tests/test_benchmark.py b/tests/modelbench_tests/test_benchmark.py index 9b683744..9c028412 100644 --- a/tests/modelbench_tests/test_benchmark.py +++ b/tests/modelbench_tests/test_benchmark.py @@ -53,7 +53,7 @@ def test_benchmark_definition_basics(): @pytest.mark.parametrize("prompt_set", PROMPT_SETS.keys()) -def test_benchmark_v1_definition_basics(prompt_set): +def test_benchmark_v1_definition_basics(prompt_set, fake_secrets): mbb = GeneralPurposeAiChatBenchmarkV1(Locale.EN_US, prompt_set) assert mbb.uid == f"general_purpose_ai_chat_benchmark-1.0-en_us-{prompt_set}-default" assert mbb.name() == "General Purpose Ai Chat Benchmark V 1" @@ -66,7 +66,7 @@ def test_benchmark_v1_definition_basics(prompt_set): assert hazard.hazard_key == hazard_key assert hazard.locale == Locale.EN_US assert hazard.prompt_set == prompt_set - assert prompt_set in hazard.tests(secrets={})[0].prompt_set_file_name + assert prompt_set in hazard.tests(secrets=fake_secrets)[0].prompt_set_file_name @pytest.mark.parametrize( diff --git a/tests/modelgauge_tests/data/sample_cache.sqlite b/tests/modelgauge_tests/data/sample_cache.sqlite index 248d66ba..15025f27 100644 Binary files a/tests/modelgauge_tests/data/sample_cache.sqlite and b/tests/modelgauge_tests/data/sample_cache.sqlite differ diff --git a/tests/modelgauge_tests/test_external_data.py b/tests/modelgauge_tests/test_external_data.py index 38297b72..8a9b62dc 100644 --- a/tests/modelgauge_tests/test_external_data.py +++ b/tests/modelgauge_tests/test_external_data.py @@ -9,14 +9,22 @@ from tenacity import wait_none +WebDataMockResponse = namedtuple("WebDataMockResponse", ("ok", "content")) GDriveFileToDownload = namedtuple("GDriveFileToDownload", ("id", "path")) -def test_web_data_download(mocker): - mock_download = mocker.patch("urllib.request.urlretrieve") +def test_web_data_download(mocker, tmpdir): + mock_download = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test")) web_data = WebData(source_url="http://example.com") - web_data.download("test.tgz") - mock_download.assert_called_once_with("http://example.com", "test.tgz", reporthook=ANY) + web_data.download(tmpdir / "file.txt") + mock_download.assert_called_once_with("http://example.com") + + +def test_web_data_download_with_headers(mocker, tmpdir): + mock_download = mocker.patch("requests.get", return_value=WebDataMockResponse(ok=True, content=b"test")) + web_data = WebData(source_url="http://example.com", headers={"token": "secret"}) + web_data.download(tmpdir / "file.txt") + mock_download.assert_called_once_with("http://example.com", headers={"token": "secret"}) def test_gdrive_data_download(mocker):