From f76f55e29b442ef4ac0c9c5c68037fa9a27743f4 Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Wed, 26 Jun 2024 13:09:39 +0900 Subject: [PATCH] try to improve tests --- README.md | 2 +- metrics/layout_occlusion/layout-occlusion.py | 13 ++- tests/conftest.py | 20 +++++ tests/layout_occulusion_test.py | 88 ++++++++++++++++++++ 4 files changed, 121 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 123d004..5c08170 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ A collection of metrics to evaluate layout generation that can be easily used in 🤗 huggingface [evaluate](https://huggingface.co/docs/evaluate/index). | 📊 Metric | 🤗 Space | 📝 Paper | -|:---------|:---------|----------| +|:---------:|:---------|:----------| | [![FID](https://github.com/creative-graphic-design/huggingface-evaluate_layout-metrics/actions/workflows/layout_generative_model_scores.yaml/badge.svg)](https://github.com/creative-graphic-design/huggingface-evaluate_layout-metrics/actions/workflows/layout_generative_model_scores.yaml) | [`creative-graphic-design/layout-generative-model-scores`](https://huggingface.co/spaces/creative-graphic-design/layout-generative-model-scores) | [[Heusel+ NeurIPS'17](https://arxiv.org/abs/1706.08500)], [[Naeem+ ICML'20](https://arxiv.org/abs/2002.09797)] | | [![Max. IoU](https://github.com/creative-graphic-design/huggingface-evaluate_layout-metrics/actions/workflows/layout_maximum_iou.yaml/badge.svg)](https://github.com/creative-graphic-design/huggingface-evaluate_layout-metrics/actions/workflows/layout_maximum_iou.yaml) | [`creative-graphic-design/layout-maximum-iou`](https://huggingface.co/spaces/creative-graphic-design/layout-maximum-iou) | [[Kikuchi+ ACMMM'21](https://arxiv.org/abs/2108.00871)] | | [![Avg. IoU](https://github.com/creative-graphic-design/huggingface-evaluate_layout-metrics/actions/workflows/layout_average_iou.yaml/badge.svg)](https://github.com/creative-graphic-design/huggingface-evaluate_layout-metrics/actions/workflows/layout_average_iou.yaml) | [`creative-graphic-design/layout-average-iou`](https://huggingface.co/spaces/creative-graphic-design/layout-average-iou) | [[Arroyo+ CVPR'21](https://arxiv.org/abs/2104.02416)], [[Kong+ ECCV'22](https://arxiv.org/abs/2112.05112)] | diff --git a/metrics/layout_occlusion/layout-occlusion.py b/metrics/layout_occlusion/layout-occlusion.py index 564c205..0d816ca 100644 --- a/metrics/layout_occlusion/layout-occlusion.py +++ b/metrics/layout_occlusion/layout-occlusion.py @@ -12,7 +12,18 @@ """ _KWARGS_DESCRIPTION = """\ -FIXME +Args: + predictions (`list` of `lists` of `float`): A list of lists of floats representing bounding boxes. + gold_labels (`list` of `lists` of `int`): A list of lists of integers representing class labels. + saliency_maps_1 (`list` of `str`): A list of strings representing path to saliency maps 1. + saliency_maps_2 (`list` of `str`): A list of strings representing path to saliency maps 2. + +Returns: + float: Average saliency of areas covered by elements. Lower values are generally better (in 0.0 - 1.0 range). + +Examples: + Examples 1: Single processing + >>> metric = evaluate.load("creative-graphic-design/layout-occlusion") """ _CITATION = """\ diff --git a/tests/conftest.py b/tests/conftest.py index e48dded..7b1b90e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,26 @@ def poster_height() -> int: return 750 +@pytest.fixture +def batch_size() -> int: + return 512 + + +@pytest.fixture +def max_layout_elements() -> int: + return 25 + + +@pytest.fixture +def num_coordinates() -> int: + return 4 + + +@pytest.fixture +def num_class_labels() -> int: + return 10 + + @pytest.fixture def is_CI() -> bool: return bool(os.environ.get("CI", False)) diff --git a/tests/layout_occulusion_test.py b/tests/layout_occulusion_test.py index a1c94c1..d24009d 100644 --- a/tests/layout_occulusion_test.py +++ b/tests/layout_occulusion_test.py @@ -1,10 +1,13 @@ +import io import os import pathlib from typing import List import evaluate +import numpy as np import pytest import torch +from PIL import Image @pytest.fixture @@ -23,6 +26,91 @@ def expected_score(is_CI: bool) -> float: return 0.15746160746433283 if is_CI else 0.20880194364379892 +def create_in_memory_saliency_maps( + batch_size: int, poster_width: int, poster_height: int +): + def _create_random_gaussian_image(w, h): + """Create a random black image with a white Gaussian.""" + # Generate random parameters for Gaussian + x, y = np.meshgrid(np.linspace(-1, 1, w), np.linspace(-1, 1, h)) + + # Generate random center for Gaussian + mu_x = np.random.rand() - 0.5 + mu_y = np.random.rand() - 0.5 + d = np.sqrt((x - mu_x) ** 2 + (y - mu_y) ** 2) + + # Generate random sigma for Gaussian + sigma = 0.2 + np.random.rand() * 0.4 + + g = np.exp(-((d) ** 2 / (2.0 * sigma**2))) + + # Create a new image with black background + image = Image.fromarray(g * 255).convert("L") + + return image + + def _create_in_memory_saliency_maps(batch_size: int): + images = [ + _create_random_gaussian_image(w=poster_width, h=poster_height) + for _ in range(batch_size) + ] + + image_filepaths = [] + for image in images: + image_io = io.BytesIO() + image.save(image_io, format="PNG") + image_io.seek(0) + image_filepaths.append(image_io) + return image_filepaths + + return _create_in_memory_saliency_maps(batch_size) + + +def test_metric_random( + metric_path: str, + batch_size: int, + poster_width: int, + poster_height: int, + max_layout_elements: int, + num_coordinates: int, + num_class_labels: int, +): + metric = evaluate.load( + path=metric_path, + canvas_width=poster_width, + canvas_height=poster_height, + ) + batch_predictions = np.random.rand( + batch_size, + max_layout_elements, + num_coordinates, + ) + batch_gold_labels = np.random.randint( + num_class_labels, + size=( + batch_size, + max_layout_elements, + 1, + ), + ) + metric.add_batch( + predictions=batch_predictions, + gold_labels=batch_gold_labels, + saliency_maps_1=create_in_memory_saliency_maps( + batch_size=batch_size, + poster_width=poster_width, + poster_height=poster_height, + ), + saliency_maps_2=create_in_memory_saliency_maps( + batch_size=batch_size, + poster_width=poster_width, + poster_height=poster_height, + ), + ) + score = metric.compute() + assert score is not None + + def test_metric( metric_path: str, poster_predictions: torch.Tensor,