From b91af154a35a082ce2f86817eb942b535b75cce4 Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Mon, 30 Sep 2024 13:47:03 +0200
Subject: [PATCH] Add scripts for metric visualization per object

---
 elf/visualisation/__init__.py                |  1 +
 elf/visualisation/metric_visualization.py    | 83 ++++++++++++++++++++
 example/visualization/check_metric_viewer.py | 40 ++++++++++
 3 files changed, 124 insertions(+)
 create mode 100644 elf/visualisation/metric_visualization.py
 create mode 100644 example/visualization/check_metric_viewer.py

diff --git a/elf/visualisation/__init__.py b/elf/visualisation/__init__.py
index 7c74c7a..84e492f 100644
--- a/elf/visualisation/__init__.py
+++ b/elf/visualisation/__init__.py
@@ -2,3 +2,4 @@
 from .grid_views import simple_grid_view
 from .object_visualisation import visualise_iou_scores, visualise_dice_scores, visualise_voi_scores
 from .size_histogram import plot_size_histogram
+from .metric_visualization import run_metric_visualization
diff --git a/elf/visualisation/metric_visualization.py b/elf/visualisation/metric_visualization.py
new file mode 100644
index 0000000..77f6892
--- /dev/null
+++ b/elf/visualisation/metric_visualization.py
@@ -0,0 +1,83 @@
+import numpy as np
+
+
+def run_metric_visualization(
+    image: np.ndarray,
+    prediction: np.ndarray,
+    ground_truth: np.ndarray,
+):
+    """Visualize true positive, false positive and false negative objects for an adjustable IoU threshold.
+    """
+    import napari
+    from magicgui import magic_factory
+
+    iou_threshold = 0.5
+    tp, fp, fn = _calculate_scores(ground_truth, prediction, iou_threshold)
+
+    viewer = napari.Viewer()
+    viewer.add_image(image)
+    viewer.add_labels(ground_truth, name='Ground Truth')
+    viewer.add_labels(prediction, name='Prediction')
+    tp_layer = viewer.add_labels(tp, name='True Positives', color={1: 'green'})
+    fp_layer = viewer.add_labels(fp, name='False Positives', color={1: 'red'})
+    fn_layer = viewer.add_labels(fn, name='False Negatives', color={1: 'blue'})
+
+    @magic_factory(
+        call_button="Update IoU Threshold",
+        iou_threshold={"widget_type": "FloatSlider", "min": 0.5, "max": 1.0, "step": 0.1}
+    )
+    def update_iou_threshold(iou_threshold=0.5):
+        new_tp, new_fp, new_fn = _calculate_scores(ground_truth, prediction, iou_threshold)
+        tp_layer.data = new_tp
+        fp_layer.data = new_fp
+        fn_layer.data = new_fn
+
+    iou_widget = update_iou_threshold()
+    viewer.window.add_dock_widget(iou_widget, name='IoU Threshold Slider')
+    napari.run()
+
+
+def _intersection_over_union(gt, prediction):
+    intersection = np.logical_and(gt, prediction).sum()
+    union = np.logical_or(gt, prediction).sum()
+    if union == 0:
+        return 0.0
+    return intersection / union
+
+
+def _calculate_scores(ground_truth, prediction, iou_threshold):
+    gt_ids = np.unique(ground_truth)
+    pred_ids = np.unique(prediction)
+
+    ignore_index = 0
+    gt_ids = gt_ids[gt_ids != ignore_index]
+    pred_ids = pred_ids[pred_ids != ignore_index]
+
+    shape = ground_truth.shape
+    tp, fp, fn = np.zeros(shape, dtype=bool), np.zeros(shape, dtype=bool), np.zeros(shape, dtype=bool)
+    matched_gt = set()
+
+    # Greedily match each predicted object to the best unmatched ground-truth object.
+    for pred_id in pred_ids:
+        best_iou = 0
+        best_gt_id = None
+        for gt_id in gt_ids:
+            if gt_id in matched_gt:
+                continue
+
+            iou = _intersection_over_union((ground_truth == gt_id), (prediction == pred_id))
+            if iou > best_iou:
+                best_iou = iou
+                best_gt_id = gt_id
+
+        if best_iou >= iou_threshold:
+            tp = np.logical_or(tp, (prediction == pred_id))
+            matched_gt.add(best_gt_id)
+        else:
+            fp = np.logical_or(fp, (prediction == pred_id))
+
+    for gt_id in gt_ids:
+        if gt_id not in matched_gt:
+            fn = np.logical_or(fn, (ground_truth == gt_id))
+
+    return tp.astype(int), fp.astype(int), fn.astype(int)
diff --git a/example/visualization/check_metric_viewer.py b/example/visualization/check_metric_viewer.py
new file mode 100644
index 0000000..31d8150
--- /dev/null
+++ b/example/visualization/check_metric_viewer.py
@@ -0,0 +1,40 @@
+import imageio.v3 as imageio
+
+from elf.visualisation import run_metric_visualization
+
+
+def _run_prediction(image_path):
+    # NOTE: overwrite this function to use your own prediction pipeline.
+    from micro_sam.automatic_segmentation import automatic_instance_segmentation
+    prediction = automatic_instance_segmentation(input_path=image_path, model_type="vit_b_lm")
+    return prediction
+
+
+def check_on_livecell(input_path, gt_path):
+    if input_path is None and gt_path is None:
+        from micro_sam.evaluation.livecell import _get_livecell_paths
+        image_paths, gt_paths = _get_livecell_paths(input_folder="/home/anwai/data/livecell")
+        image_path, gt_path = image_paths[0], gt_paths[0]
+    else:
+        image_path = input_path
+
+    image = imageio.imread(image_path)
+    ground_truth = imageio.imread(gt_path)
+
+    prediction = _run_prediction(image_path)
+
+    # Visualize metrics over the prediction and ground truth.
+    run_metric_visualization(image, prediction, ground_truth)
+
+
+def main(args):
+    check_on_livecell(input_path=args.input_path, gt_path=args.gt_path)
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input_path", type=str, default=None)
+    parser.add_argument("-gt", "--gt_path", type=str, default=None)
+    args = parser.parse_args()
+    main(args)
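
A minimal sketch for trying run_metric_visualization directly, without the LiveCELL example script; the synthetic image, prediction and ground-truth arrays below are hypothetical placeholders, assuming napari, magicgui and elf are installed:

import numpy as np
from elf.visualisation import run_metric_visualization

# Synthetic 2d data: two ground-truth objects, one matched well and one missed.
image = np.random.rand(128, 128).astype("float32")

ground_truth = np.zeros((128, 128), dtype="uint32")
ground_truth[10:40, 10:40] = 1
ground_truth[60:100, 60:100] = 2

prediction = np.zeros((128, 128), dtype="uint32")
prediction[12:42, 12:42] = 1    # good overlap with object 1 -> true positive
prediction[5:15, 100:110] = 2   # spurious object -> false positive

# Opens napari with TP / FP / FN layers and the IoU threshold slider;
# ground-truth object 2 has no match and shows up as a false negative.
run_metric_visualization(image, prediction, ground_truth)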