Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shunk031 committed Jun 25, 2024
1 parent f07d43c commit 85e1005
Showing 1 changed file with 13 additions and 42 deletions.
55 changes: 13 additions & 42 deletions tests/layout_utility_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
from typing import List

import evaluate
import pytest
Expand All @@ -17,65 +18,35 @@ def metric_path(base_dir: str) -> str:


@pytest.fixture
def test_fixture_dir() -> pathlib.Path:
return pathlib.Path(__file__).parents[1] / "test_fixtures"


@pytest.fixture
def poster_width() -> int:
return 513


@pytest.fixture
def poster_height() -> int:
return 750


@pytest.fixture
def saliency_maps_1_dir(test_fixture_dir: pathlib.Path):
return test_fixture_dir / "PKU_PosterLayout" / "test" / "saliencymaps_pfpn"


@pytest.fixture
def saliency_maps_2_dir(test_fixture_dir: pathlib.Path):
return test_fixture_dir / "PKU_PosterLayout" / "test" / "saliencymaps_basnet"
def expected_score(is_CI: bool) -> float:
# https://github.com/PKU-ICST-MIPL/PosterLayout-CVPR2023/blob/main/output/results.txt#L2C14-L2C31
return 0.24395973228151718 if is_CI else 0.25410159915056757


def test_metric(
metric_path: str,
test_fixture_dir: pathlib.Path,
poster_predictions: torch.Tensor,
poster_gold_labels: torch.Tensor,
poster_width: int,
poster_height: int,
saliency_maps_1_dir: pathlib.Path,
saliency_maps_2_dir: pathlib.Path,
# https://github.com/PKU-ICST-MIPL/PosterLayout-CVPR2023/blob/main/output/results.txt#L2C14-L2C31
expected_score: float = 0.25410159915056757,
saliency_map_filepaths_1: List[pathlib.Path],
saliency_map_filepaths_2: List[pathlib.Path],
expected_score: float,
):
image_names = torch.load(test_fixture_dir / "poster_layout_test_order.pt")

saliency_map_filepaths_1 = [
saliency_maps_1_dir / name.replace(".", "_pred.") for name in image_names
]
saliency_map_filepaths_2 = [saliency_maps_2_dir / name for name in image_names]
assert len(saliency_map_filepaths_1) == len(saliency_map_filepaths_2)

# Convert pathlib.Path to str
saliency_map_filepaths_1 = [[str(path)] for path in saliency_map_filepaths_1]
saliency_map_filepaths_2 = [[str(path)] for path in saliency_map_filepaths_2]

# shape: (batch_size, max_elements, 4)
predictions = torch.load(test_fixture_dir / "poster_layout_boxes.pt")
# shape: (batch_size, max_elements, 1)
gold_labels = torch.load(test_fixture_dir / "poster_layout_clses.pt")
saliency_map_filepaths_1 = [[str(path)] for path in saliency_map_filepaths_1] # type: ignore
saliency_map_filepaths_2 = [[str(path)] for path in saliency_map_filepaths_2] # type: ignore

metric = evaluate.load(
path=metric_path,
canvas_width=poster_width,
canvas_height=poster_height,
)
metric.add_batch(
predictions=predictions,
gold_labels=gold_labels,
predictions=poster_predictions,
gold_labels=poster_gold_labels,
saliency_maps_1=saliency_map_filepaths_1,
saliency_maps_2=saliency_map_filepaths_2,
)
Expand Down

0 comments on commit 85e1005

Please sign in to comment.