Skip to content

Commit

Permalink
feat: copy qs from another dataset (#692)
Browse files Browse the repository at this point in the history
* feat: copy qs from another dataset

* fix tests

* respond to comments

---------

Co-authored-by: Grant <grant@kolena.io>
  • Loading branch information
nankolena and grant-Kolena authored Sep 25, 2024
1 parent d264bbe commit 88992cd
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 0 deletions.
1 change: 1 addition & 0 deletions kolena/_api/v1/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class Event(str, Enum):

# quality-standard
FETCH_QUALITY_STANDARD_RESULT = "sdk-quality-standard-result-fetched"
COPY_QUALITY_STANDARD_FROM_DATASET = "sdk-quality-standard-copied-from-dataset"

@dataclass(frozen=True)
class RecordEventRequest:
Expand Down
11 changes: 11 additions & 0 deletions kolena/_api/v2/quality_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
# limitations under the License.
from enum import Enum

from kolena._utils.pydantic_v1.dataclasses import dataclass


class Path(str, Enum):
QUALITY_STANDARD = "quality-standard"
RESULT = "quality-standard/result"
COPY_FROM_DATASET = "quality-standard/copy-from-dataset"


@dataclass(frozen=True)
class CopyQualityStandardRequest:
dataset_id: int
source_dataset_id: int
include_metric_groups: bool = True
include_test_cases: bool = True
1 change: 1 addition & 0 deletions kolena/_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from kolena._experimental.quality_standard import copy_quality_standards_from_dataset
from kolena._experimental.quality_standard import download_quality_standard_result
56 changes: 56 additions & 0 deletions kolena/_experimental/quality_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import asdict
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import pandas as pd

from kolena._api.v1.event import EventAPI
from kolena._api.v2.quality_standard import CopyQualityStandardRequest
from kolena._api.v2.quality_standard import Path
from kolena._utils import krequests_v2 as krequests
from kolena._utils import log
from kolena._utils.instrumentation import with_event
from kolena.dataset.dataset import _load_dataset_metadata
from kolena.errors import IncorrectUsageError


def _format_quality_standard_result_df(quality_standard_result: dict) -> pd.DataFrame:
Expand Down Expand Up @@ -91,3 +98,52 @@ def download_quality_standard_result(
for model in models:
result_dfs.append(_download_quality_standard_result(dataset, [model], metric_groups, intersect_results))
return pd.concat(result_dfs, axis=1)


@with_event(event_name=EventAPI.Event.COPY_QUALITY_STANDARD_FROM_DATASET)
def copy_quality_standards_from_dataset(
dataset: str,
source_dataset: str,
include_metric_groups: bool = True,
include_test_cases: bool = True,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Create a quality standard on a dataset by copying from a source dataset. Note that this operation will overwrite the
existing quality standards on the dataset if they exist.
:param dataset: The name of the dataset.
:param source_dataset: The name of the dataset from which the quality standards should be copied.
:param include_metric_groups: Optional flag to indicate whether to copy the metric groups from the source dataset.
:param include_test_cases: Optional flag to indicate whether to copy the test cases from the source dataset.
:return: A tuple of the created metric groups and test cases.
"""
if dataset == source_dataset:
raise IncorrectUsageError("source dataset and target dataset are the same")

if not include_test_cases and not include_metric_groups:
raise IncorrectUsageError("should include at least one of metric group or test case")

dataset_metadata = _load_dataset_metadata(dataset)
if not dataset_metadata:
raise IncorrectUsageError(f"The dataset with name '{dataset}' not found")
source_dataset_metadata = _load_dataset_metadata(source_dataset)
if not source_dataset_metadata:
raise IncorrectUsageError(f"The source dataset with name '{source_dataset}' not found")

request = CopyQualityStandardRequest(
dataset_metadata.id,
source_dataset_metadata.id,
include_metric_groups=include_metric_groups,
include_test_cases=include_test_cases,
)

response = krequests.put(
Path.COPY_FROM_DATASET,
json=asdict(request),
api_version="v2",
)
krequests.raise_for_status(response)

metric_groups = response.json().get("metric_groups", [])
test_cases = response.json().get("stratifications", [])
return metric_groups, test_cases
98 changes: 98 additions & 0 deletions tests/integration/_experimental/test_quality_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple

Expand All @@ -20,10 +22,12 @@
import pytest
from pandas.testing import assert_frame_equal

from kolena._experimental import copy_quality_standards_from_dataset
from kolena._experimental import download_quality_standard_result
from kolena.dataset import upload_dataset
from kolena.dataset.evaluation import _upload_results
from kolena.dataset.evaluation import EvalConfig
from kolena.errors import IncorrectUsageError
from tests.integration._experimental.helper import create_quality_standard
from tests.integration.helper import fake_locator
from tests.integration.helper import with_test_prefix
Expand Down Expand Up @@ -276,3 +280,97 @@ def test__download_quality_standard_result__union(
]
== waterloo_minimum
)


def test__copy_quality_standards_from_dataset__dataset_same_as_source() -> None:
dataset_name = with_test_prefix("test__copy_quality_standards_from_dataset__dataset_same_as_source")
source_dataset_name = dataset_name
with pytest.raises(IncorrectUsageError) as exc_info:
copy_quality_standards_from_dataset(dataset_name, source_dataset_name)
exc_info_value = str(exc_info.value)
assert "source dataset and target dataset are the same" in exc_info_value


def _assert_metric_groups_equal(metric_groups_1: List[Dict[str, Any]], metric_groups_2: List[Dict[str, Any]]) -> None:
assert len(metric_groups_1) == len(metric_groups_2)
for metric_group_1, metric_group_2 in zip(metric_groups_1, metric_groups_2):
assert metric_group_1["name"] == metric_group_2["name"]
assert len(metric_group_1["metrics"]) == len(metric_group_2["metrics"])
for metric_1, metric_2 in zip(metric_group_1["metrics"], metric_group_2["metrics"]):
assert metric_1["label"] == metric_2["label"]


def _assert_test_cases_equal(test_cases_list_1: List[Dict[str, Any]], test_cases_list_2: List[Dict[str, Any]]) -> None:
assert len(test_cases_list_1) == len(test_cases_list_2)
for test_cases_1, test_cases_2 in zip(test_cases_list_1, test_cases_list_2):
assert test_cases_1["name"] == test_cases_2["name"]
assert len(test_cases_1["test_cases"]) == len(test_cases_2["test_cases"])
for metric_1, metric_2 in zip(test_cases_1["test_cases"], test_cases_2["test_cases"]):
assert metric_1["name"] == metric_2["name"]


def test__copy_quality_standards_from_dataset(datapoints: pd.DataFrame) -> None:
source_dataset_name = with_test_prefix("test__copy_quality_standards_from_dataset__source_dataset")
dataset_name = with_test_prefix("test__copy_quality_standards_from_dataset__dataset")

upload_dataset(source_dataset_name, datapoints, id_fields=ID_FIELDS)
upload_dataset(dataset_name, datapoints, id_fields=ID_FIELDS)

quality_standards = dict(
name=with_test_prefix("test__copy_quality_standards_from_dataset__qs"),
stratifications=[
dict(
name=with_test_prefix("test__copy_quality_standards_from_dataset__test-case"),
stratify_fields=[dict(source="datapoint", field="city", values=["new york", "waterloo"])],
test_cases=[
dict(name="new york", stratification=[dict(value="new york")]),
dict(name="waterloo", stratification=[dict(value="waterloo")]),
],
),
],
metric_groups=[
dict(
name=with_test_prefix("test__copy_quality_standards_from_dataset__metric_group"),
metrics=[
dict(label="Max Score", source="result", aggregator="max", params=dict(key="score")),
dict(label="Min Score", source="result", aggregator="min", params=dict(key="score")),
],
),
],
version="1.0",
)
create_quality_standard(source_dataset_name, quality_standards)

# by default, should copy both metric groups and test cases
metric_groups, test_cases = copy_quality_standards_from_dataset(dataset_name, source_dataset_name)
_assert_metric_groups_equal(quality_standards["metric_groups"], metric_groups)
_assert_test_cases_equal(quality_standards["stratifications"], test_cases)

# exclude metric groups
metric_groups, test_cases = copy_quality_standards_from_dataset(
dataset_name,
source_dataset_name,
include_metric_groups=False,
)
assert metric_groups == []
_assert_test_cases_equal(quality_standards["stratifications"], test_cases)

# exclude test cases
metric_groups, test_cases = copy_quality_standards_from_dataset(
dataset_name,
source_dataset_name,
include_test_cases=False,
)
_assert_metric_groups_equal(quality_standards["metric_groups"], metric_groups)
assert test_cases == []

# cannot exclude both test cases and metric groups
with pytest.raises(IncorrectUsageError) as exc_info:
copy_quality_standards_from_dataset(
dataset_name,
source_dataset_name,
include_metric_groups=False,
include_test_cases=False,
)
exc_info_value = str(exc_info.value)
assert "should include at least one of metric groups or test cases" in exc_info_value

0 comments on commit 88992cd

Please sign in to comment.