Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp #1503] Add InstanceSegMetric to MMEval #70

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .end_point_error import EndPointError
from .f_metric import F1Metric
from .hmean_iou import HmeanIoU
from .instance_seg import InstanceSeg
from .mae import MAE
from .mean_iou import MeanIoU
from .mse import MSE
Expand All @@ -25,5 +26,5 @@
'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric',
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
'AveragePrecision', 'AVAMeanAP', 'BLEU'
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'InstanceSeg'
]
265 changes: 265 additions & 0 deletions mmeval/metrics/instance_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from copy import deepcopy
from typing import Dict, List, Sequence

from mmeval.core.base_metric import BaseMetric
from mmeval.metrics.utils import scannet_eval


def aggregate_predictions(masks, labels, scores, valid_class_ids):
"""Maps predictions to ScanNet evaluator format.

Args:
masks (list[numpy.ndarray]): Per scene predicted instance masks.
labels (list[numpy.ndarray]): Per scene predicted instance labels.
scores (list[numpy.ndarray]): Per scene predicted instance scores.
valid_class_ids (tuple[int]): Ids of valid categories.

Returns:
list[dict]: Per scene aggregated predictions.
"""
Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved
infos = []
for id, (mask, label, score) in enumerate(zip(masks, labels, scores)):
info = dict()
n_instances = mask.max() + 1
for i in range(n_instances):
# match pred_instance['filename'] from assign_instances_for_scan
file_name = f'{id}_{i}'
info[file_name] = dict()
info[file_name]['mask'] = (mask == i).astype(int)
info[file_name]['label_id'] = valid_class_ids[label[i]]
info[file_name]['conf'] = score[i]
infos.append(info)
return infos


def rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids):
Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved
"""Maps gt instance and semantic masks to instance masks for ScanNet
evaluator.

Args:
gt_semantic_masks (list[numpy.ndarray]): Per scene gt semantic masks.
gt_instance_masks (list[numpy.ndarray]): Per scene gt instance masks.
valid_class_ids (tuple[int]): Ids of valid categories.

Returns:
list[np.array]: Per scene instance masks.
"""
renamed_instance_masks = []
for semantic_mask, instance_mask in zip(gt_semantic_masks,
gt_instance_masks):
unique = np.unique(instance_mask)
assert len(unique) < 1000
for i in unique:
semantic_instance = semantic_mask[instance_mask == i]
semantic_unique = np.unique(semantic_instance)
assert len(semantic_unique) == 1
if semantic_unique[0] < len(valid_class_ids):
instance_mask[
instance_mask ==
i] = 1000 * valid_class_ids[semantic_unique[0]] + i
renamed_instance_masks.append(instance_mask)
return renamed_instance_masks


def instance_seg_eval(gt_semantic_masks,
gt_instance_masks,
pred_instance_masks,
pred_instance_labels,
pred_instance_scores,
valid_class_ids,
class_labels,
options=None):
Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved
"""Instance Segmentation Evaluation.

Evaluate the result of the instance segmentation.

Args:
gt_semantic_masks (list[numpy.ndarray]): Ground truth semantic masks.
gt_instance_masks (list[numpy.ndarray]): Ground truth instance masks.
pred_instance_masks (list[numpy.ndarray]): Predicted instance masks.
pred_instance_labels (list[numpy.ndarray]): Predicted instance labels.
pred_instance_scores (list[numpy.ndarray]): Predicted instance scores.
valid_class_ids (tuple[int]): Ids of valid categories.
class_labels (tuple[str]): Names of valid categories.
options (dict, optional): Additional options. Keys may contain:
`overlaps`, `min_region_sizes`, `distance_threshes`,
`distance_confs`. Default: None.
logger (logging.Logger | str, optional): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None.

Returns:
dict[str, float]: Dict of results.
"""
assert len(valid_class_ids) == len(class_labels)
id_to_label = {
valid_class_ids[i]: class_labels[i]
for i in range(len(valid_class_ids))
}
preds = aggregate_predictions(
masks=pred_instance_masks,
labels=pred_instance_labels,
scores=pred_instance_scores,
valid_class_ids=valid_class_ids)
gts = rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids)
metrics = scannet_eval(
preds=preds,
gts=gts,
options=options,
valid_class_ids=valid_class_ids,
class_labels=class_labels,
id_to_label=id_to_label)
return metrics


class InstanceSeg(BaseMetric):
"""3D instance segmentation evaluation metric.

Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved
Args:
dataset_meta (dict): Provide dataset meta information.

Example:
>>> import numpy as np
>>> from mmeval import InstanceSegMetric
>>> seg_valid_class_ids = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
>>> 28, 33, 34, 36, 39)
>>> class_labels = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
... 'window', 'bookshelf', 'picture', 'counter', 'desk',
... 'curtain', 'refrigerator', 'showercurtrain', 'toilet',
... 'sink', 'bathtub', 'garbagebin')
>>> dataset_meta = dict(
... seg_valid_class_ids=seg_valid_class_ids, classes=class_labels)
>>>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example is too complicated and needs to be simplified~

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is hard to simplify it for this metric need many code to generate input.

>>> def _demo_mm_model_output(self):
... n_points_list = [3300, 3000]
... gt_labels_list = [[0, 0, 0, 0, 0, 0, 14, 14, 2, 1],
... [13, 13, 2, 1, 3, 3, 0, 0, 0]]
...
... predictions = []
... groundtruths = []
...
... for n_points, gt_labels in zip(n_points_list, gt_labels_list):
... gt_instance_mask = np.ones(n_points, dtype=int) * -1
... gt_semantic_mask = np.ones(n_points, dtype=int) * -1
... for i, gt_label in enumerate(gt_labels):
... begin = i * 300
... end = begin + 300
... gt_instance_mask[begin:end] = i
... gt_semantic_mask[begin:end] = gt_label
...
... ann_info_data = dict()
... ann_info_data['pts_instance_mask'] = torch.tensor(
... gt_instance_mask)
... ann_info_data['pts_semantic_mask'] = torch.tensor(
... gt_semantic_mask)
...
... results_dict = dict()
... pred_instance_mask = np.ones(n_points, dtype=int) * -1
... labels = []
... scores = []
... for i, gt_label in enumerate(gt_labels):
... begin = i * 300
... end = begin + 300
... pred_instance_mask[begin:end] = i
... labels.append(gt_label)
... scores.append(.99)
...
... results_dict['pts_instance_mask'] = torch.tensor(
... pred_instance_mask)
... results_dict['instance_labels'] = torch.tensor(labels)
... results_dict['instance_scores'] = torch.tensor(scores)
...
... predictions.append(results_dict)
... groundtruths.append(ann_info_data)
...
... return predictions, groundtruths
>>>
>>> instance_seg_metric = InstanceSegMetric(dataset_meta=dataset_meta)
>>> res = instance_seg_metric(predictions, groundtruths)
>>> res
{
'all_ap': 1.0,
'all_ap_50%': 1.0,
'all_ap_25%': 1.0,
'classes': {
'cabinet': {
'ap': 1.0,
'ap50%': 1.0,
'ap25%': 1.0
},
'bed': {
'ap': 1.0,
'ap50%': 1.0,
'ap25%': 1.0
},
'chair': {
'ap': 1.0,
'ap50%': 1.0,
'ap25%': 1.0
},
...
}
}
Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
assert self.dataset_meta is not None
self.classes = self.dataset_meta['classes']
self.valid_class_ids = self.dataset_meta['seg_valid_class_ids']
Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved

def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Process one batch of data samples and predictions.

The processed results should be stored in ``self.results``,
which will be used to compute the metrics when all batches
have been processed.
Args:
predictions (Sequence[Dict]): A sequence of dict. Each dict
representing a detection result, with the following keys:

- pts_instance_mask(numpy.ndarray): Predicted instance masks.
- instance_labels(numpy.ndarray): Predicted instance labels.
- instance_scores(numpy.ndarray): Predicted instance scores.
groundtruths (Sequence[Dict]): A sequence of dict. Each dict
represents a groundtruths for an image, with the following
keys:
- pts_instance_mask(numpy.ndarray): Ground truth instance masks.
- pts_semantic_mask(numpy.ndarray): Ground truth semantic masks.
"""
Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved
for prediction, groundtruth in zip(predictions, groundtruths):
self._results.append((deepcopy(prediction), deepcopy(groundtruth)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need deepcopy~

Copy link
Author

@Pzzzzz5142 Pzzzzz5142 Jan 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't use deepcopy, the metric itself will change the value of the tensor inplace which may cause confusion.


def compute_metric(self, results: List[List[Dict]]) -> Dict[str, float]:
"""Compute the metrics from processed results.

Args:
results (list): The processed results of each batch.
Returns:
Dict[str, float]: The computed metrics. The keys are the names of
the metrics, and the values are corresponding results.
Pzzzzz5142 marked this conversation as resolved.
Show resolved Hide resolved
"""
gt_semantic_masks = []
gt_instance_masks = []
pred_instance_masks = []
pred_instance_labels = []
pred_instance_scores = []

for result_pred, result_gt in results:
gt_semantic_masks.append(result_gt['pts_semantic_mask'])
gt_instance_masks.append(result_gt['pts_instance_mask'])
pred_instance_masks.append(result_pred['pts_instance_mask'])
pred_instance_labels.append(result_pred['instance_labels'])
pred_instance_scores.append(result_pred['instance_scores'])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A suggestion that simplify the code~

Suggested change
gt_semantic_masks = []
gt_instance_masks = []
pred_instance_masks = []
pred_instance_labels = []
pred_instance_scores = []
for result_pred, result_gt in results:
gt_semantic_masks.append(result_gt['pts_semantic_mask'])
gt_instance_masks.append(result_gt['pts_instance_mask'])
pred_instance_masks.append(result_pred['pts_instance_mask'])
pred_instance_labels.append(result_pred['instance_labels'])
pred_instance_scores.append(result_pred['instance_scores'])
gt_semantic_masks = [gt['pts_semantic_mask'] for _, gt in results]
gt_instance_masks = [gt['pts_instance_mask'] for _, gt in results]
pred_instance_masks = [pred['pts_instance_mask'] for pred, _ in results]
pred_instance_labels = [pred['instance_labels'] for pred, _ in results]
pred_instance_scores = [pred['instance_scores'] for pred, _ in results]

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this way, we would do 5 for loops, which may be time consuming.


ret_dict = instance_seg_eval(
gt_semantic_masks,
gt_instance_masks,
pred_instance_masks,
pred_instance_labels,
pred_instance_scores,
valid_class_ids=self.valid_class_ids,
class_labels=self.classes)

return ret_dict
4 changes: 3 additions & 1 deletion mmeval/metrics/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_overlaps import calculate_bboxes_area, calculate_overlaps
from .evaluate_semantic_instance import scannet_eval
from .image_transforms import reorder_and_crop
from .keypoint import calc_distances, distance_acc
from .polygon import (poly2shapely, poly_intersection, poly_iou,
Expand All @@ -8,5 +9,6 @@
__all__ = [
'poly2shapely', 'polys2shapely', 'poly_union', 'poly_intersection',
'poly_make_valid', 'poly_iou', 'calc_distances', 'distance_acc',
'calculate_overlaps', 'calculate_bboxes_area', 'reorder_and_crop'
'calculate_overlaps', 'calculate_bboxes_area', 'reorder_and_crop',
'scannet_eval'
]
Loading