COCO mAP metric #2901

Open
wants to merge 56 commits into base: master
Commits (56 total; changes shown from 51 commits)
34a1a3f
Keep only cocomap-related changes
sadra-barikbin May 17, 2023
24fe980
Some improvements
sadra-barikbin May 28, 2023
e2ac8ee
Update docs
sadra-barikbin May 28, 2023
e4683de
Merge branch 'master' into cocomap
sadra-barikbin May 28, 2023
7cf53e1
Fix a bug in docs
sadra-barikbin May 29, 2023
4aa9c5d
Fix a tiny bug related to allgather
sadra-barikbin Jun 15, 2023
950c388
Fix a few bugs
sadra-barikbin Jun 16, 2023
9f5f796
Redesign code:
sadra-barikbin Jun 16, 2023
ffb1ba4
Merge branch 'master' into cocomap
sadra-barikbin Jun 16, 2023
65cdd08
Remove all_gather with different shape
sadra-barikbin Jun 17, 2023
e54af52
Merge branch 'master' into cocomap
sadra-barikbin Jun 21, 2023
aac2e55
Add test for all_gather_with_different_shape func
sadra-barikbin Jun 21, 2023
4cf3972
Merge branch 'master' into cocomap
vfdev-5 Jun 21, 2023
6070e18
A few improvements
sadra-barikbin Aug 23, 2023
aa83e60
Merge remote-tracking branch 'upstream/cocomap' into cocomap
sadra-barikbin Aug 23, 2023
deebbde
Add an output transform
sadra-barikbin Aug 31, 2023
62ca5fb
Add a test for the output_transform
sadra-barikbin Aug 31, 2023
418fcf4
Remove 'flavor' because all DeciAI, Ultralytics, Detectron and pycoco…
sadra-barikbin Sep 1, 2023
5fea0cd
Merge branch 'master' into cocomap
sadra-barikbin Sep 1, 2023
79fa1e2
Revert Metric change and a few bug fix
sadra-barikbin Sep 10, 2023
26c96b8
A tiny improvement in local variable names
sadra-barikbin Sep 15, 2023
d18f793
Merge branch 'master' into cocomap
sadra-barikbin Sep 15, 2023
a361ca8
Add max_dep and area_range
sadra-barikbin Dec 4, 2023
ce48583
some improvements
sadra-barikbin Jun 28, 2024
cf02dc0
Improvement in code
sadra-barikbin Jul 11, 2024
1593dfb
Some improvements
sadra-barikbin Jul 12, 2024
e425e12
Fix a bug; Some improvements; Improve docs
sadra-barikbin Jul 16, 2024
bb15f0f
Fix metrics.rst
sadra-barikbin Jul 16, 2024
a184ba5
Merge branch 'master' into cocomap
sadra-barikbin Jul 16, 2024
6fcc97f
Remove @override which is for 3.12
sadra-barikbin Jul 16, 2024
120c755
Fix mypy issues
sadra-barikbin Jul 16, 2024
7c26d08
Fix two tests
sadra-barikbin Jul 16, 2024
c3c4a82
Fix a typo in tests
sadra-barikbin Jul 16, 2024
2405937
Fix dist tests
sadra-barikbin Jul 16, 2024
9b3100d
Merge branch 'master' into cocomap
sadra-barikbin Sep 3, 2024
356f618
Add common obj. det. metrics
sadra-barikbin Sep 3, 2024
bbfc4c7
Merge branch 'master' into cocomap
sadra-barikbin Sep 3, 2024
cb6a328
Change an annotation for the sake of M1 python3.8
sadra-barikbin Sep 3, 2024
248fe89
Use if check on torch.double usages for MPS backend
sadra-barikbin Sep 3, 2024
8bfb802
Fix a typo
sadra-barikbin Sep 4, 2024
4038c2b
Fix a bug related to tensors on same devices
sadra-barikbin Sep 4, 2024
4b6afdd
Fix a bug related to MPS and torch.double
sadra-barikbin Sep 4, 2024
d0e82b3
Fix a bug related to MPS
sadra-barikbin Sep 4, 2024
085e0df
Fix a bug related to MPS
sadra-barikbin Sep 4, 2024
3658f95
Fix a bug related to MPS
sadra-barikbin Sep 4, 2024
0444933
Resolve MPS's lack of cummax
sadra-barikbin Sep 4, 2024
c433718
Revert MPS fallback
sadra-barikbin Sep 4, 2024
dacf407
Apply comments
sadra-barikbin Sep 4, 2024
67454c3
Merge branch 'master' into cocomap
sadra-barikbin Sep 5, 2024
67e38c4
Revert unnecessary changes
sadra-barikbin Sep 5, 2024
978791b
Merge branch 'master' into cocomap
vfdev-5 Sep 9, 2024
7b43c69
Apply review comments
sadra-barikbin Sep 20, 2024
4d3fc57
Merge branch 'master' into cocomap
sadra-barikbin Sep 20, 2024
479d1b7
Merge branch 'master' into cocomap
sadra-barikbin Sep 29, 2024
954d130
Skip MPS on test_integraion as well
sadra-barikbin Sep 29, 2024
d2978cf
Merge branch 'master' into cocomap
sadra-barikbin Oct 5, 2024
4 changes: 4 additions & 0 deletions docs/source/metrics.rst
@@ -332,13 +332,17 @@ Complete list of metrics
Frequency
Loss
MeanAbsoluteError
MeanAveragePrecision
MeanPairwiseDistance
MeanSquaredError
metric.Metric
metric_group.MetricGroup
metrics_lambda.MetricsLambda
MultiLabelConfusionMatrix
MutualInformation
ObjectDetectionAvgPrecisionRecall
CommonObjectDetectionMetrics
vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list
precision.Precision
PSNR
recall.Recall
10 changes: 10 additions & 0 deletions ignite/metrics/__init__.py
@@ -19,6 +19,7 @@
from ignite.metrics.loss import Loss
from ignite.metrics.maximum_mean_discrepancy import MaximumMeanDiscrepancy
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_average_precision import MeanAveragePrecision
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
@@ -37,6 +38,11 @@
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
from ignite.metrics.vision.object_detection_average_precision_recall import (
coco_tensor_list_to_dict_list,
CommonObjectDetectionMetrics,
ObjectDetectionAvgPrecisionRecall,
)

__all__ = [
"Metric",
@@ -86,4 +92,8 @@
"PrecisionRecallCurve",
"RocCurve",
"ROC_AUC",
"MeanAveragePrecision",
"ObjectDetectionAvgPrecisionRecall",
"CommonObjectDetectionMetrics",
"coco_tensor_list_to_dict_list",
]
385 changes: 385 additions & 0 deletions ignite/metrics/mean_average_precision.py

Large diffs are not rendered by default.
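
The new metric's source is collapsed in this view. For orientation only, here is a minimal usage sketch of the classification-style MeanAveragePrecision; the constructor argument and the update/compute protocol are assumptions inferred from the tests added further down in this PR, not from the un-rendered file, so check mean_average_precision.py for the authoritative signature.

import torch

from ignite.metrics import MeanAveragePrecision

# "macro" is one of the class_mean options exercised in the tests below (assumption).
metric = MeanAveragePrecision(class_mean="macro")

# Multiclass case: (N, C) scores and (N,) integer targets, mirroring the
# (N, C, ...) / (N, ...) shapes used in test_compute_nonbinary_data.
y_pred = torch.rand(8, 5)
y_true = torch.randint(0, 5, (8,))

metric.update((y_pred, y_true))   # one call per batch
print(metric.compute())           # mean average precision over the 5 classes

The metric can also be attached to an Engine, as done in test_distrib_integration below via metric.attach(engine, "mAP").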

2 changes: 1 addition & 1 deletion ignite/metrics/metric.py
@@ -873,7 +873,7 @@ def _is_list_of_tensors_or_numbers(x: Sequence[Union[torch.Tensor, float]]) -> bool:
return isinstance(x, Sequence) and all([isinstance(t, (torch.Tensor, Number)) for t in x])


def _to_batched_tensor(x: Union[torch.Tensor, float], device: Optional[torch.device] = None) -> torch.Tensor:
def _to_batched_tensor(x: Union[torch.Tensor, Number], device: Optional[torch.device] = None) -> torch.Tensor:
if isinstance(x, torch.Tensor):
return x.unsqueeze(dim=0)
return torch.tensor([x], device=device)
4 changes: 2 additions & 2 deletions ignite/metrics/metric_group.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Sequence
from typing import Any, Callable, Dict, Sequence, Tuple

import torch

@@ -36,7 +36,7 @@ class MetricGroup(Metric):
state.metrics["eval_metrics"]
"""

_state_dict_all_req_keys = ("metrics",)
_state_dict_all_req_keys: Tuple[str, ...] = ("metrics",)

def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
self.metrics = metrics
3 changes: 3 additions & 0 deletions ignite/metrics/vision/__init__.py
@@ -0,0 +1,3 @@
from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionAvgPrecisionRecall

__all__ = ["ObjectDetectionAvgPrecisionRecall"]
482 changes: 482 additions & 0 deletions ignite/metrics/vision/object_detection_average_precision_recall.py

Large diffs are not rendered by default.
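
This file is also collapsed. As a heavily hedged sketch of how the detection metric might be driven: the dict keys below follow the usual torchvision/COCO convention ("boxes", "scores", "labels") and are an assumption; the authoritative input format is defined in the un-rendered object_detection_average_precision_recall.py.

import torch

from ignite.metrics import ObjectDetectionAvgPrecisionRecall

metric = ObjectDetectionAvgPrecisionRecall()

# One image: a single predicted box with its confidence and label, plus one
# ground-truth box. Keys are torchvision-style (assumption, see lead-in).
pred = {
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 60.0]]),  # (x1, y1, x2, y2)
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}
target = {
    "boxes": torch.tensor([[12.0, 11.0, 48.0, 59.0]]),
    "labels": torch.tensor([1]),
}

metric.update(([pred], [target]))
print(metric.compute())  # AP/AR values; exact return structure per the implementation

The exported coco_tensor_list_to_dict_list helper presumably converts plain tensor outputs into such a dict-list format; its docstring in the same file should specify the exact tensor layout it expects.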

1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -25,6 +25,7 @@ pynvml
clearml
scikit-image
py-rouge
pycocotools
# temporary fix for python=3.12 and v3.8.1
# nltk
git+https://github.com/nltk/nltk@aba99c8
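
pycocotools joins the dev requirements, presumably so the new detection metric can be cross-checked against the reference COCO evaluator in tests. For context, the standard pycocotools evaluation loop looks as follows (generic API, not specific to this PR; the JSON paths are placeholders):

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("instances_val.json")          # ground-truth annotations
coco_dt = coco_gt.loadRes("detections.json")  # model detections in COCO results format

coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()        # prints the AP/AR table
print(coco_eval.stats[0])    # AP @ IoU=0.50:0.95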
205 changes: 205 additions & 0 deletions tests/ignite/metrics/test_mean_average_precision.py
@@ -0,0 +1,205 @@
import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score, precision_recall_curve

from ignite import distributed as idist
from ignite.engine import Engine
from ignite.metrics import MeanAveragePrecision
from ignite.utils import manual_seed, to_onehot

manual_seed(41)


def test_wrong_input():
with pytest.raises(ValueError, match="rec_thresholds should be a one-dimensional tensor or a sequence of floats"):
MeanAveragePrecision(rec_thresholds=torch.zeros((2, 2)))

with pytest.raises(TypeError, match="rec_thresholds should be a sequence of floats or a tensor"):
MeanAveragePrecision(rec_thresholds={0, 0.2, 0.4, 0.6, 0.8})

with pytest.raises(ValueError, match="Wrong `class_mean` parameter"):
MeanAveragePrecision(class_mean="samples")

with pytest.raises(ValueError, match="rec_thresholds values should be between 0 and 1"):
MeanAveragePrecision(rec_thresholds=(0.0, 0.5, 1.0, 1.5))

metric = MeanAveragePrecision()
with pytest.raises(RuntimeError, match="Metric could not be computed without any update method call"):
metric.compute()


def test_wrong_classification_input():
metric = MeanAveragePrecision()

with pytest.raises(TypeError, match="`y_pred` should be a float tensor"):
metric.update((torch.tensor([0, 1, 0]), torch.tensor([1, 0, 1])))

metric = MeanAveragePrecision()
with pytest.warns(RuntimeWarning, match="`y` should be of dtype long when entry type is multiclass"):
metric.update((torch.tensor([[0.5, 0.4, 0.1]]), torch.tensor([2.0])))

with pytest.raises(ValueError, match="y_pred contains fewer classes than y"):
metric.update((torch.tensor([[0.5, 0.4, 0.1]]), torch.tensor([3])))


def test__prepare_output():
metric = MeanAveragePrecision()

metric._type = "binary"
scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool()))
assert scores.shape == y.shape == (1, 120)

metric._type = "multiclass"
scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 4, (5, 3, 2))))
assert scores.shape == (4, 30) and y.shape == (30,)

metric._type = "multilabel"
scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool()))
assert scores.shape == y.shape == (4, 30)


def test_update():
metric = MeanAveragePrecision()
assert len(metric._y_pred) == len(metric._y_true) == 0
metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4)).bool()))
assert len(metric._y_pred) == len(metric._y_true) == 1


def test__compute_recall_and_precision():
m = MeanAveragePrecision()

scores = torch.rand((50,))
y_true = torch.randint(0, 2, (50,)).bool()
precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy())
P = y_true.sum(dim=-1)
ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P)
assert (ignite_recall.squeeze().flip(0).numpy() == recall[:-1]).all()
assert (ignite_precision.squeeze().flip(0).numpy() == precision[:-1]).all()

# When there is no actual positive; numpy expectedly raises a warning.
scores = torch.rand((50,))
y_true = torch.zeros((50,)).bool()
precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy())
P = torch.tensor(0)
ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P)
assert (ignite_recall.flip(0).numpy() == recall[:-1]).all()
assert (ignite_precision.flip(0).numpy() == precision[:-1]).all()


def test__compute_average_precision():
m = MeanAveragePrecision()

# Binary data
scores = np.random.rand(50)
y_true = np.random.randint(0, 2, 50)
ap = average_precision_score(y_true, scores)
precision, recall, _ = precision_recall_curve(y_true, scores)
ignite_ap = m._compute_average_precision(
torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1)
)
assert np.allclose(ignite_ap.item(), ap)

# Multilabel data
scores = np.random.rand(50, 5)
y_true = np.random.randint(0, 2, (50, 5))
ap = average_precision_score(y_true, scores, average=None)
ignite_ap = []
for cls in range(scores.shape[1]):
precision, recall, _ = precision_recall_curve(y_true[:, cls], scores[:, cls])
ignite_ap.append(
m._compute_average_precision(
torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1)
).item()
)
ignite_ap = np.array(ignite_ap)
assert np.allclose(ignite_ap, ap)


def test_compute_binary_data():
m = MeanAveragePrecision()
scores = torch.rand((130,))
y_true = torch.randint(0, 2, (130,))

m.update((scores[:50], y_true[:50]))
m.update((scores[50:], y_true[50:]))
ignite_map = m.compute()

map = average_precision_score(y_true.numpy(), scores.numpy())

assert np.allclose(ignite_map, map)


@pytest.mark.parametrize("class_mean", [None, "macro", "micro", "weighted"])
def test_compute_nonbinary_data(class_mean):
scores = torch.rand((130, 5, 2, 2))
sklearn_scores = scores.transpose(1, -1).reshape(-1, 5).numpy()

# Multiclass
m = MeanAveragePrecision(class_mean=class_mean)
y_true = torch.randint(0, 5, (130, 2, 2))
m.update((scores[:50], y_true[:50]))
m.update((scores[50:], y_true[50:]))
ignite_map = m.compute().numpy()

y_true = to_onehot(y_true, 5).transpose(1, -1).reshape(-1, 5).numpy()
sklearn_map = average_precision_score(y_true, sklearn_scores, average=class_mean)

assert np.allclose(sklearn_map, ignite_map)

# Multilabel
m = MeanAveragePrecision(is_multilabel=True, class_mean=class_mean)
y_true = torch.randint(0, 2, (130, 5, 2, 2)).bool()
m.update((scores[:50], y_true[:50]))
m.update((scores[50:], y_true[50:]))
ignite_map = m.compute().numpy()

y_true = y_true.transpose(1, -1).reshape(-1, 5).numpy()
sklearn_map = average_precision_score(y_true, sklearn_scores, average=class_mean)

assert np.allclose(sklearn_map, ignite_map)


@pytest.mark.parametrize("data_type", ["binary", "multiclass", "multilabel"])
def test_distrib_integration(distributed, data_type):
rank = idist.get_rank()
world_size = idist.get_world_size()
device = idist.device()

def _test(metric_device):
def update(_, i):
return (
y_preds[(2 * rank + i) * 10 : (2 * rank + i + 1) * 10],
y_true[(2 * rank + i) * 10 : (2 * rank + i + 1) * 10],
)

engine = Engine(update)
mAP = MeanAveragePrecision(is_multilabel=data_type == "multilabel", device=metric_device)
mAP.attach(engine, "mAP")

y_true_size = (10 * 2 * world_size, 3, 2) if data_type != "multilabel" else (10 * 2 * world_size, 4, 3, 2)
y_true = torch.randint(0, 4 if data_type == "multiclass" else 2, size=y_true_size).to(device)
y_preds_size = (10 * 2 * world_size, 4, 3, 2) if data_type != "binary" else (10 * 2 * world_size, 3, 2)
y_preds = torch.rand(y_preds_size).to(device)

engine.run(range(2), max_epochs=1)
assert "mAP" in engine.state.metrics

if data_type == "multiclass":
y_true = to_onehot(y_true, 4)

if data_type == "binary":
y_true = y_true.view(-1)
y_preds = y_preds.view(-1)
else:
y_true = y_true.transpose(1, -1).reshape(-1, 4)
y_preds = y_preds.transpose(1, -1).reshape(-1, 4)

sklearn_mAP = average_precision_score(y_true.numpy(), y_preds.numpy())
assert np.allclose(sklearn_mAP, engine.state.metrics["mAP"])

metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
_test(metric_device)