Skip to content

Commit

Permalink
chore: validate results df has more than thresholded fields before uploading (#679)
Browse files Browse the repository at this point in the history

* validate results df has more than thresholded fields before uploading

* fix name in test
  • Loading branch information
nankolena authored Sep 10, 2024
1 parent 8be77f5 commit 9789f8f
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 17 deletions.
11 changes: 8 additions & 3 deletions kolena/dataset/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def validate_dataframe_ids(df: pd.DataFrame, id_fields: List[str]) -> None:
_validate_dataframe_ids_uniqueness(df, id_fields)


def validate_dataframe_have_other_columns_besides_ids(df: pd.DataFrame, id_fields: List[str]) -> None:
    """Raise when the dataframe's columns are exactly the ID fields.

    Args:
        df: Dataframe whose column names are inspected.
        id_fields: Names of the ID columns.

    Raises:
        InputValidationError: If the set of column names equals the set of
            ``id_fields``.
    """
    column_names = set(df.columns)
    id_names = set(id_fields)
    # Guard clause: any difference between the two sets means validation passes.
    if column_names != id_names:
        return
    raise InputValidationError("dataframe only contains id fields")
def validate_dataframe_columns(
    df: pd.DataFrame,
    id_fields: List[str],
    thresholded_fields: Optional[List[str]] = None,
) -> None:
    """Reject a dataframe that has no columns besides ID and thresholded fields.

    Args:
        df: Dataframe whose column names are inspected.
        id_fields: Names of the ID columns.
        thresholded_fields: Names of the thresholded columns, if any. ``None``
            or an empty list means only ``id_fields`` form the minimal set.

    Raises:
        InputValidationError: If every column of ``df`` is an ID field or a
            thresholded field, i.e. there is nothing else in the dataframe.
    """
    minimal_fields = set(id_fields).union(thresholded_fields or ())
    # Subset rather than equality: a dataframe that carries the ID columns but
    # omits a declared thresholded column still contains nothing beyond the
    # minimal set, and an equality check would let it through.
    if set(df.columns) <= minimal_fields:
        raise InputValidationError("dataframe only contains id fields and thresholded fields")
6 changes: 3 additions & 3 deletions kolena/dataset/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from kolena.dataset._common import COL_THRESHOLDED_OBJECT
from kolena.dataset._common import DEFAULT_SOURCES
from kolena.dataset._common import validate_batch_size
from kolena.dataset._common import validate_dataframe_have_other_columns_besides_ids
from kolena.dataset._common import validate_dataframe_columns
from kolena.dataset._common import validate_dataframe_ids
from kolena.dataset.dataset import _load_dataset_metadata
from kolena.dataset.dataset import _to_deserialized_dataframe
Expand Down Expand Up @@ -256,7 +256,7 @@ def _prepare_upload_results_request(
if isinstance(df_result_input, pd.DataFrame):
total_rows += df_result_input.shape[0]
validate_dataframe_ids(df_result_input, id_fields)
validate_dataframe_have_other_columns_besides_ids(df_result_input, id_fields)
validate_dataframe_columns(df_result_input, id_fields, thresholded_fields)
df_results = _process_result(config, df_result_input, id_fields, thresholded_fields)
upload_data_frame(df=df_results, load_uuid=load_uuid)

Expand All @@ -265,7 +265,7 @@ def _prepare_upload_results_request(
for df_result in df_result_input:
if not id_column_validated:
validate_dataframe_ids(df_result, id_fields)
validate_dataframe_have_other_columns_besides_ids(df_result, id_fields)
validate_dataframe_columns(df_result, id_fields, thresholded_fields)
id_column_validated = True
total_rows += df_result.shape[0]
df_results = _process_result(config, df_result, id_fields, thresholded_fields)
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/dataset/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from kolena.dataset.dataset import _load_dataset_metadata
from kolena.dataset.evaluation import _upload_results
from kolena.errors import IncorrectUsageError
from kolena.errors import InputValidationError
from kolena.errors import NotFoundError
from tests.integration.dataset.test_dataset import batch_iterator
from tests.integration.helper import assert_frame_equal
Expand Down Expand Up @@ -490,6 +491,31 @@ def test__upload_results__thresholded() -> None:
assert_frame_equal(fetched_df_result, expected_df_result, result_columns)


def test__upload_results__only_id_and_thresholded_columns() -> None:
    """Uploading results must fail validation when nothing but thresholded data is supplied.

    The result frame built here holds only ``user_dp_id`` plus the ``bev``
    column, and ``bev`` is declared as the sole thresholded field, so
    ``_upload_results`` is expected to raise ``InputValidationError``.
    """
    raw_name = f"{__file__}::test__upload_results__only_id_and_thresholded_columns"
    dataset_name = with_test_prefix(raw_name)
    model_name = with_test_prefix(raw_name)
    df_dp = get_df_dp()
    dp_columns = [JOIN_COLUMN, "locator", "width", "height", "city"]
    upload_dataset(dataset_name, df_dp[3:10][dp_columns], id_fields=ID_FIELDS)

    records = []
    for i in range(20):
        # Three thresholded entries per row, at thresholds 0.1, 0.2 and 0.3.
        bev = [{"threshold": (j + 1) * 0.1, "label": "cat", "foo": i + j} for j in range(3)]
        records.append({"user_dp_id": i, "bev": bev})
    df_result = pd.DataFrame(records)

    with pytest.raises(InputValidationError):
        _upload_results(
            dataset_name,
            model_name,
            df_result,
            thresholded_fields=["bev"],
        )


def test__download_results__dataset_does_not_exist() -> None:
dataset_name = with_test_prefix(f"{__file__}::test__download_results__dataset_does_not_exist")
model_name = with_test_prefix(f"{__file__}::test__download_results__dataset_does_not_exist")
Expand Down
38 changes: 27 additions & 11 deletions tests/unit/dataset/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Optional

import pandas as pd
import pytest

from kolena.dataset._common import validate_dataframe_have_other_columns_besides_ids
from kolena.dataset._common import validate_dataframe_columns
from kolena.dataset._common import validate_dataframe_ids
from kolena.dataset._common import validate_id_fields
from kolena.errors import DuplicateDatapointIdError
Expand Down Expand Up @@ -101,23 +102,38 @@ def test__validate_dataframe_ids__duplicate_id() -> None:


@pytest.mark.parametrize(
"df, id_fields",
"df, id_fields, thresholded_fields",
[
(pd.DataFrame(dict(a=[1, 2, 3], b=[1, 2, 1])), ["a"]),
(pd.DataFrame({"a.text": [1, 2, 3], "b.text": [1, 2, 1]}), ["a.text"]),
(pd.DataFrame(dict(a=[1, 2, 3], b=[1, 2, 1])), ["a"], None),
(pd.DataFrame({"a.text": [1, 2, 3], "b.text": [1, 2, 1]}), ["a.text"], None),
(pd.DataFrame({"a.text": [1, 2, 3], "b.text": [1, 2, 1]}), ["a.text"], []),
(
pd.DataFrame({"a.text": [1, 2, 3], "b.text": [1, 2, 1], "threshold": [0.1, 0.2, 0.3]}),
["a.text"],
["threshold"],
),
],
)
def test__validate_dataframe_have_other_columns_besides_ids(df: pd.DataFrame, id_fields: List[str]) -> None:
validate_dataframe_have_other_columns_besides_ids(df, id_fields)
def test__validate_dataframe_columns_besides_ids(
df: pd.DataFrame,
id_fields: List[str],
thresholded_fields: Optional[List[str]],
) -> None:
validate_dataframe_columns(df, id_fields, thresholded_fields)


@pytest.mark.parametrize(
"df, id_fields",
"df, id_fields, thresholded_fields",
[
(pd.DataFrame(dict(a=[1, 2, 3], b=[1, 2, 1])), ["a", "b"]),
(pd.DataFrame({"a.text": [1, 2, 3], "b.text": [1, 2, 1]}), ["a.text", "b.text"]),
(pd.DataFrame(dict(a=[1, 2, 3], b=[1, 2, 1])), ["a", "b"], None),
(pd.DataFrame({"a.text": [1, 2, 3], "b.text": [1, 2, 1]}), ["a.text", "b.text"], None),
(pd.DataFrame({"a.text": [1, 2, 3], "b.text": [1, 2, 1]}), ["a.text"], ["b.text"]),
],
)
def test__validate_dataframe_have_other_columns_besides_ids__error(df: pd.DataFrame, id_fields: List[str]) -> None:
def test__validate_dataframe_columns__error(
df: pd.DataFrame,
id_fields: List[str],
thresholded_fields: Optional[List[str]],
) -> None:
with pytest.raises(InputValidationError):
validate_dataframe_have_other_columns_besides_ids(df, id_fields)
validate_dataframe_columns(df, id_fields, thresholded_fields)

0 comments on commit 9789f8f

Please sign in to comment.