Merge pull request #105 from constantinpape/vis-metric
Metric visualization
constantinpape authored Oct 2, 2024
2 parents 2ad2e44 + da9d6a5 commit 12b42d3
Showing 5 changed files with 163 additions and 14 deletions.
26 changes: 16 additions & 10 deletions elf/evaluation/matching.py
@@ -1,3 +1,5 @@
+from typing import Optional, Tuple
+
 import numpy as np
 from scipy.optimize import linear_sum_assignment
 from .util import contigency_table
@@ -48,22 +50,26 @@ def f1(tp, fp, fn):
     return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0


-def label_overlap(seg_a, seg_b, ignore_label=0):
-    """ Number of overlapping elements for objects in two label images
+def label_overlap(
+    seg_a: np.ndarray,
+    seg_b: np.ndarray,
+    ignore_label: Optional[int] = 0,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Compute the number of overlapping elements for objects in two label images.

-    Arguments:
-        seg_a [np.ndarray] - candidate segmentation to evaluate
-        seg_b [np.ndarray] - candidate segmentation to compare to
-        ignore_label [int] - overlap of any objects with this label are not
+    Args:
+        seg_a: candidate segmentation to evaluate
+        seg_b: candidate segmentation to compare to
+        ignore_label: overlap of any objects with this label are not
             taken into account in the output. `None` indicates that no label
             should be ignored. It is assumed that the `ignore_label` has the
             same meaning in both segmentations.

     Returns:
-        np.ndarray[uint64] - each cell i,j has the count of overlapping elements
-            of object i in `seg_a` with obj j in `seg_b`. Note: indices in the
-            returned matrix do not correspond to object ids anymore.
-        tuple[int, int] - index of ignore label in label_overlap output matrix
+        Matrix with cells i,j containing the count of overlapping elements
+            of object i in `seg_a` with object j in `seg_b`.
+            Note: indices in the returned matrix may not correspond to object ids anymore.
+        Index of the ignore label in the output matrix.
     """
     p_ids, p_counts = contigency_table(seg_a, seg_b)[2:]
     p_ids = p_ids.astype("uint64")
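For orientation, here is a minimal sketch of how the updated `label_overlap` can be called; the toy arrays and the comments are illustrative assumptions, not part of the commit.

import numpy as np
from elf.evaluation.matching import label_overlap

# Toy label images: 0 is background, 1 and 2 are objects (values made up for illustration).
seg_a = np.array([[1, 1, 0], [2, 2, 0]], dtype="uint32")
seg_b = np.array([[1, 0, 0], [2, 2, 2]], dtype="uint32")

# overlap[i, j] counts the pixels shared by object i in seg_a and object j in seg_b;
# per the docstring, the second return value is the index of the ignore label in the matrix.
overlap, ignore_idx = label_overlap(seg_a, seg_b, ignore_label=0)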
23 changes: 19 additions & 4 deletions elf/evaluation/util.py
@@ -1,26 +1,41 @@
+from typing import Dict, Tuple
+
 import numpy as np
 import nifty.ground_truth as ngt


-def contigency_table(seg_a, seg_b):
+def contigency_table(
+    seg_a: np.ndarray,
+    seg_b: np.ndarray,
+) -> Tuple[Dict[int, int], Dict[int, int], np.ndarray, np.ndarray]:
     """ Compute the pairs and counts in the contingency table of seg_a and seg_b.
+
+    The contingency table counts the number of pixels that are shared between
+    objects from seg_a and seg_b.
+
+    Args:
+        seg_a: the first segmentation.
+        seg_b: the second segmentation.
+
+    Returns:
+        Dictionary that maps the ids in seg_a to their pixel counts.
+        Dictionary that maps the ids in seg_b to their pixel counts.
+        The pairs in the contingency table, giving first the id in seg_a and then the one in seg_b.
+        The overlap counts in the contingency table.
     """
     # compute the unique ids and counts for seg a and seg b
     # and wrap them in a dict
     a_ids, a_counts = np.unique(seg_a, return_counts=True)
     b_ids, b_counts = np.unique(seg_b, return_counts=True)
-    a_dict = dict(zip(a_ids, a_counts.astype('float64')))
-    b_dict = dict(zip(b_ids, b_counts.astype('float64')))
+    a_dict = dict(zip(a_ids, a_counts.astype("float64")))
+    b_dict = dict(zip(b_ids, b_counts.astype("float64")))

     # compute the overlaps and overlap counts
     # use nifty gt functionality
     ovlp_comp = ngt.overlap(seg_a, seg_b)
     ovlps = [ovlp_comp.overlapArrays(ida, sorted=False) for ida in a_ids]
     p_ids = np.array([[ida, idb] for ida, ovlp in zip(a_ids, ovlps) for idb in ovlp[0]])
-    p_counts = np.concatenate([ovlp[1] for ovlp in ovlps]).astype('float64')
+    p_counts = np.concatenate([ovlp[1] for ovlp in ovlps]).astype("float64")
     assert len(p_ids) == len(p_counts)

     # this is the alternative (naive) numpy impl, unfortunately this is very slow and
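The truncated comment above refers to a slow, naive numpy alternative. As a rough sketch of what `contigency_table` computes, assuming plain numpy only (the helper name is hypothetical and this is illustration, not the library's code):

import numpy as np

def naive_contingency_pairs(seg_a, seg_b):
    # Pair up the labels of corresponding pixels and count each unique pair.
    pairs = np.stack([seg_a.ravel(), seg_b.ravel()], axis=1)
    p_ids, p_counts = np.unique(pairs, axis=0, return_counts=True)
    return p_ids, p_counts.astype("float64")

seg_a = np.array([[1, 1, 0], [2, 2, 0]])
seg_b = np.array([[1, 0, 0], [2, 2, 2]])
p_ids, p_counts = naive_contingency_pairs(seg_a, seg_b)
# Each row of p_ids is an (id_a, id_b) pair; p_counts holds the shared pixel counts.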
1 change: 1 addition & 0 deletions elf/visualisation/__init__.py
@@ -2,3 +2,4 @@
 from .grid_views import simple_grid_view
 from .object_visualisation import visualise_iou_scores, visualise_dice_scores, visualise_voi_scores
 from .size_histogram import plot_size_histogram
+from .metric_visualization import run_metric_visualization
85 changes: 85 additions & 0 deletions elf/visualisation/metric_visualization.py
@@ -0,0 +1,85 @@
import numpy as np

from skimage.segmentation import relabel_sequential
from elf.evaluation.matching import label_overlap, intersection_over_union


def _compute_matches(prediction, ground_truth, overlap_matrix, iou_threshold):
    matches = overlap_matrix > iou_threshold

    # Get the TP and FP ids by checking which rows have / don't have a match.
    pred_matches = np.any(matches, axis=1)
    tp_ids = np.where(pred_matches)[0]
    if 0 in tp_ids:
        tp_ids = tp_ids[1:]
    fp_ids = np.where(~pred_matches)[0]
    if 0 in fp_ids:
        fp_ids = fp_ids[1:]

    # Get the FN ids by checking which columns don't have a match.
    gt_matches = np.any(matches, axis=0)
    fn_ids = np.where(~gt_matches)[0]
    if 0 in fn_ids:
        fn_ids = fn_ids[1:]

    # Compute masks based on the ids.
    tp = np.isin(prediction, tp_ids)
    fp = np.isin(prediction, fp_ids)
    fn = np.isin(ground_truth, fn_ids)

    return tp, fp, fn


def run_metric_visualization(
    image: np.ndarray,
    prediction: np.ndarray,
    ground_truth: np.ndarray,
):
    """Visualize the metric scores over a range of thresholds.

    Args:
        image: The input image.
        prediction: The prediction generated for the input image.
        ground_truth: The true labels for the input image.
    """
    import napari
    from magicgui import magic_factory

    ground_truth = relabel_sequential(ground_truth)[0]
    prediction = relabel_sequential(prediction)[0]

    # Compute the overlaps for objects in the prediction and ground-truth.
    overlap_matrix = intersection_over_union(label_overlap(prediction, ground_truth, ignore_label=None)[0])

    # Compute the initial TPs, FPs and FNs based on an IOU threshold of 0.5.
    iou_threshold = 0.5
    tp, fp, fn = _compute_matches(prediction, ground_truth, overlap_matrix, iou_threshold)

    viewer = napari.Viewer()
    viewer.add_image(image)
    viewer.add_labels(ground_truth, name="Ground Truth")
    viewer.add_labels(prediction, name="Prediction")

    # The keyword changed from color -> colormap with napari 0.5.
    try:
        tp_layer = viewer.add_labels(tp, name="True Positives", color={1: "green"})
        fp_layer = viewer.add_labels(fp, name="False Positives", color={1: "red"})
        fn_layer = viewer.add_labels(fn, name="False Negatives", color={1: "blue"})
    except TypeError:
        tp_layer = viewer.add_labels(tp, name="True Positives", colormap={1: "green"})
        fp_layer = viewer.add_labels(fp, name="False Positives", colormap={1: "red"})
        fn_layer = viewer.add_labels(fn, name="False Negatives", colormap={1: "blue"})

    @magic_factory(
        call_button="Update IoU Threshold",
        iou_threshold={"widget_type": "FloatSlider", "min": 0.1, "max": 1.0, "step": 0.05}
    )
    def update_iou_threshold(iou_threshold: float = 0.5):
        new_tp, new_fp, new_fn = _compute_matches(prediction, ground_truth, overlap_matrix, iou_threshold)
        tp_layer.data = new_tp
        fp_layer.data = new_fp
        fn_layer.data = new_fn

    iou_widget = update_iou_threshold()
    viewer.window.add_dock_widget(iou_widget, name="IoU Threshold Slider")
    napari.run()
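A minimal way to try the new viewer on synthetic data, before the LiveCELL example below; the toy arrays are made up for illustration, and the call opens a napari window that blocks until it is closed:

import numpy as np
from elf.visualisation import run_metric_visualization

rng = np.random.default_rng(0)
image = rng.random((256, 256)).astype("float32")

ground_truth = np.zeros((256, 256), dtype="uint32")
ground_truth[32:96, 32:96] = 1

prediction = np.zeros((256, 256), dtype="uint32")
prediction[30:90, 36:100] = 1  # an imperfect prediction of the same object

# Opens napari with TP/FP/FN layers and the IoU-threshold slider widget.
run_metric_visualization(image, prediction, ground_truth)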
42 changes: 42 additions & 0 deletions example/visualization/check_metric_viewer.py
@@ -0,0 +1,42 @@
import imageio.v3 as imageio

from elf.visualisation import run_metric_visualization

# To simplify switching the folder.
INPUT_FOLDER = "/home/anwai/data/livecell"
# INPUT_FOLDER = "/home/pape/Work/data/incu_cyte/livecell"


def _run_prediction(image_path):
    # NOTE: overwrite this function to use your own prediction pipeline.
    from micro_sam.automatic_segmentation import automatic_instance_segmentation
    prediction = automatic_instance_segmentation(input_path=image_path, model_type="vit_b_lm")
    return prediction


def check_on_livecell(input_path, gt_path):
    if input_path is None and gt_path is None:
        from micro_sam.evaluation.livecell import _get_livecell_paths
        image_paths, gt_paths = _get_livecell_paths(input_folder=INPUT_FOLDER)
        image_path, gt_path = image_paths[0], gt_paths[0]
    else:
        # Use the image path passed on the command line.
        image_path = input_path

    image = imageio.imread(image_path)
    ground_truth = imageio.imread(gt_path)

    prediction = _run_prediction(image_path)

    # Visualize metrics over the prediction and ground truth.
    run_metric_visualization(image, prediction, ground_truth)


def main(args):
    check_on_livecell(input_path=args.input_path, gt_path=args.gt_path)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input_path", type=str, default=None)
    parser.add_argument("-gt", "--gt_path", type=str, default=None)
    args = parser.parse_args()
    main(args)
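Assuming micro_sam is installed, the script can then be invoked along these lines (the paths are placeholders):

python check_metric_viewer.py -i /path/to/image.tif -gt /path/to/labels.tif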
